Implement abs kernel, and use broken unary operation for all the compiler errors
favilo committed Dec 3, 2023
1 parent 837a265 commit 167ee4b
Showing 26 changed files with 424 additions and 561 deletions.
67 changes: 62 additions & 5 deletions dfdx-core/src/tensor/webgpu/device.rs
@@ -1,6 +1,7 @@
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue,
RequestDeviceError,
RequestDeviceError, ShaderModule, ShaderModuleDescriptor,
};

use crate::{
@@ -14,10 +15,16 @@ use crate::{
#[cfg(feature = "no-std")]
use spin::Mutex;

use core::any::TypeId;
#[cfg(not(feature = "no-std"))]
use std::sync::Mutex;

use std::{marker::PhantomData, sync::Arc, vec::Vec};
use std::{
collections::HashMap,
marker::PhantomData,
sync::{Arc, RwLock},
vec::Vec,
};

use super::allocate::round_to_buffer_alignment;

@@ -40,12 +47,16 @@ impl Buffer {
self.size
}

pub(crate) fn len<E: Unit>(&self) -> usize {
self.size / std::mem::size_of::<E>()
}

#[allow(unused)]
pub(crate) fn capacity(&self) -> usize {
self.data.size() as usize
}

pub(crate) fn copy_to_device<E: Unit>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
pub(crate) fn copy_to_device<E>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
let slice = unsafe {
std::slice::from_raw_parts(
slice.as_ptr() as *const u8,
@@ -102,6 +113,7 @@ pub struct Webgpu {
pub(crate) queue: Arc<Queue>,

pub(crate) cache: Arc<TensorCache<Buffer>>,
pub(crate) cs_cache: Arc<RwLock<HashMap<TypeId, Arc<ShaderModule>>>>,
}

impl From<RequestDeviceError> for Error {
@@ -147,18 +159,19 @@ impl Webgpu {
queue,

cache: Default::default(),
cs_cache: Default::default(),
})
}
}

impl Webgpu {
pub(crate) unsafe fn alloc_empty<E>(&self, len: usize) -> Result<Buffer, Error> {
pub(crate) fn alloc_empty<E>(&self, len: usize) -> Result<Buffer, Error> {
let data = self.cache.try_pop::<E>(len).map_or_else(
|| Buffer {
data: self.dev.create_buffer(&BufferDescriptor {
label: None,
size: round_to_buffer_alignment((len * std::mem::size_of::<E>()) as u64),
usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
mapped_at_creation: false,
}),
size: len * std::mem::size_of::<E>(),
@@ -168,6 +181,50 @@
Ok(data)
}

pub(crate) fn alloc_init<E>(&self, init: &[E]) -> Result<Buffer, Error> {
let data = self.cache.try_pop::<E>(init.len()).map_or_else(
|| {
let contents = unsafe {
std::slice::from_raw_parts(
init.as_ptr() as *const u8,
init.len() * std::mem::size_of::<E>(),
)
};
Buffer {
data: self.dev.create_buffer_init(&BufferInitDescriptor {
label: None,
usage: BufferUsages::STORAGE
| BufferUsages::COPY_SRC
| BufferUsages::COPY_DST,
contents,
}),
size: init.len() * std::mem::size_of::<E>(),
}
},
|bfr| {
bfr.copy_to_device::<E>(&self.dev, &self.queue, init);
bfr
},
);
Ok(data)
}
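
Both allocators try the tensor cache first (`cache.try_pop::<E>(len)`) and only create a fresh wgpu buffer on a miss; `alloc_init` additionally uploads the initial contents, either at creation time via `create_buffer_init` or with `copy_to_device` when a cached buffer is reused. A minimal sketch of a call site, assuming a `dev: &Webgpu` inside a function that returns `Result<(), Error>` (the call site itself is illustrative, not from this commit):

```rust
// Hypothetical usage within the crate; `dev` is a &Webgpu.
let a = dev.alloc_init(&[1.0f32, -2.0, 3.0])?; // cache hit: reuse + upload; miss: create_buffer_init
let b = dev.alloc_empty::<f32>(3)?;            // uninitialized STORAGE|COPY_SRC|COPY_DST buffer
assert_eq!(a.len::<f32>(), b.len::<f32>());    // len() = size / size_of::<E>(), added in this diff
```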

pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
self.cs_cache.read().unwrap().contains_key(&name)
}

pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) {
let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(source.into()),
}));
self.cs_cache.write().unwrap().insert(name, module);
}

pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().unwrap().get(&name).cloned()
}

// #[allow(unused)]
// pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
// let num_bytes_required = len * std::mem::size_of::<E>();
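The `cs_cache` added above keys compiled `ShaderModule`s by the op's `TypeId`, so each kernel's WGSL is compiled at most once per device. A minimal sketch of the intended check-then-load pattern, assuming crate-internal access (the helper itself is hypothetical):

```rust
use core::any::TypeId;
use std::sync::Arc;

// Hypothetical helper; Webgpu and its cache methods are from the diff above.
fn cached_module(dev: &Webgpu, key: TypeId, source: &str) -> Arc<wgpu::ShaderModule> {
    if !dev.shader_module_loaded(key) {
        dev.load_shader_module(key, source); // compiles and inserts under the write lock
    }
    dev.get_shader_module(key).expect("module was just inserted")
}
```

Note that the loaded-check and the insert take the `RwLock` separately, so two threads can race to compile the same module; the second insert simply overwrites the first, which wastes one compilation but stays correct.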
40 changes: 40 additions & 0 deletions dfdx-core/src/tensor_ops/abs/abs.wgsl
@@ -0,0 +1,40 @@
// WGSL forbids empty structs, so a dummy field stands in for the zero-sized op.
struct AbsKernelOp {
    _padding: u32,
}

@group(0)
@binding(0)
var<storage, read> op: AbsKernelOp;

@group(0)
@binding(1)
var<storage, read> inp: array<f32>;

@group(0)
@binding(2)
var<storage, read_write> out: array<f32>;

@group(0)
@binding(3)
var<storage, read_write> inp_grad: array<f32>;

@group(0)
@binding(4)
var<storage, read_write> out_grad: array<f32>;

@compute
@workgroup_size(1)
fn abs_fwd_f32(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    // `if` is a statement in WGSL, not an expression, and arrayLength takes a
    // pointer. When no input buffer is bound, operate on `out` in place.
    if arrayLength(&inp) > 0u {
        out[i] = abs(inp[i]);
    } else {
        out[i] = abs(out[i]);
    }
}

@compute
@workgroup_size(1)
fn abs_bwd_f32(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Not needed for Abs, but if we can figure out a template system, we can leave it all in.
// let x = if arrayLength(inp) > 0 { inp[global_id] } else { 0.0 };
// let y = if arrayLength(out) > 0 { out[global_id] } else { 0.0 };
    let i = global_id.x;
    // Function-scope locals use `let`, not var<private>; d|x|/dx = sign(x).
    let dx = sign(inp[i]);

    inp_grad[i] += dx * out_grad[i];
}
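
With `@workgroup_size(1)`, the host must dispatch one workgroup per element. A hedged wgpu-0.18-era sketch of running `abs_fwd_f32` (not part of this commit; `device`, `queue`, `module`, the `inp`/`out` buffers, and `num_elements` are assumed, and descriptor fields shift slightly between wgpu versions):

```rust
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
    label: Some("abs_fwd_f32"),
    layout: None, // infer the layout from the shader's @group/@binding attributes
    module: &module,
    entry_point: "abs_fwd_f32",
});
// With an inferred layout, bind only what the entry point actually uses
// (here bindings 1 and 2: `inp` and `out`).
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
    label: None,
    layout: &pipeline.get_bind_group_layout(0),
    entries: &[
        wgpu::BindGroupEntry { binding: 1, resource: inp.as_entire_binding() },
        wgpu::BindGroupEntry { binding: 2, resource: out.as_entire_binding() },
    ],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(num_elements as u32, 1, 1); // one invocation per element
}
queue.submit(std::iter::once(encoder.finish()));
```

A `@workgroup_size(64)` with `ceil(n / 64)` workgroups and an in-bounds check would be the more conventional layout; one-invocation workgroups are just the simplest thing that can work.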
30 changes: 4 additions & 26 deletions dfdx-core/src/tensor_ops/abs/webgpu_kernel.rs
@@ -1,28 +1,6 @@
use std::borrow::Cow;
use super::AbsKernelOp;
use crate::tensor_ops::webgpu_kernels::webgpu_unary;

use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};
const WGSL: &str = include_str!("abs.wgsl");

impl<E: Dtype> UnaryKernel<super::AbsKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = false;

const BACKWARD_WITHOUT_DATA: bool = false;

fn forward<S: crate::prelude::Shape>(
&self,
op: super::AbsKernelOp,
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
todo!()
}

fn backward<S: crate::prelude::Shape>(
&self,
op: super::AbsKernelOp,
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_inp: &mut Self::Vec,
out: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), crate::prelude::Error> {
todo!()
}
}
webgpu_unary!(AbsKernelOp, f32, WGSL, "abs_fwd_f32", "abs_bwd_f32");
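
The `webgpu_unary!` macro (defined in `webgpu_kernels.rs`, not shown in this diff) replaces the handwritten `todo!()` impls deleted above. Based on those deleted impls, a plausible shape for what an invocation like this generates, with the trait items exactly as this commit shows them (the bodies here are placeholders, an assumption rather than the macro's real expansion):

```rust
use std::borrow::Cow;
use crate::prelude::{ops::UnaryKernel, Webgpu};

impl UnaryKernel<AbsKernelOp, f32> for Webgpu {
    const BACKWARD_WITHOUT_INP: bool = false;
    const BACKWARD_WITHOUT_DATA: bool = false;

    fn forward<S: crate::prelude::Shape>(
        &self,
        op: AbsKernelOp,
        inp: Cow<crate::prelude::Tensor<S, f32, Self>>,
    ) -> Result<crate::prelude::Tensor<S, f32, Self>, crate::prelude::Error> {
        // Presumably: fetch the cached ShaderModule for AbsKernelOp's TypeId
        // (loading the WGSL on first use) and dispatch the "abs_fwd_f32" entry point.
        todo!()
    }

    fn backward<S: crate::prelude::Shape>(
        &self,
        op: AbsKernelOp,
        inp: &impl crate::prelude::Tensorlike<S, f32, Self>,
        grad_inp: &mut Self::Vec,
        out: &impl crate::prelude::Tensorlike<S, f32, Self>,
        grad_out: &Self::Vec,
    ) -> Result<(), crate::prelude::Error> {
        // Presumably: dispatch "abs_bwd_f32", accumulating grad_inp += sign(inp) * grad_out.
        todo!()
    }
}
```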
35 changes: 9 additions & 26 deletions dfdx-core/src/tensor_ops/accurate_gelu/webgpu_kernel.rs
@@ -1,28 +1,11 @@
use std::borrow::Cow;
use crate::prelude::webgpu_kernels::webgpu_unary;

use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};
const WGSL: &str = "TODO";

impl<E: Dtype> UnaryKernel<super::AccurateGeLUKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = false;

const BACKWARD_WITHOUT_DATA: bool = false;

fn forward<S: crate::prelude::Shape>(
&self,
op: super::AccurateGeLUKernelOp,
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
todo!()
}

fn backward<S: crate::prelude::Shape>(
&self,
op: super::AccurateGeLUKernelOp,
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_inp: &mut Self::Vec,
out: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), crate::prelude::Error> {
todo!()
}
}
webgpu_unary!(
super::AccurateGeLUKernelOp,
f32,
WGSL,
"gelu_fwd_f32",
"gelu_bwd_f32"
);
27 changes: 4 additions & 23 deletions dfdx-core/src/tensor_ops/add/webgpu_kernel.rs
@@ -1,34 +1,15 @@
use super::{BinaryAddKernelOp as Binary, ScalarAddKernelOp as Scalar};
use std::borrow::Cow;

use crate::prelude::{
ops::{BinaryKernel, UnaryKernel},
webgpu_kernels::webgpu_unary,
Dtype, Webgpu,
};

impl<E: Dtype> UnaryKernel<super::ScalarAddKernelOp<E>, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = false;
const WGSL: &str = "TODO";

const BACKWARD_WITHOUT_DATA: bool = true;

fn forward<S: crate::prelude::Shape>(
&self,
op: super::ScalarAddKernelOp<E>,
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
todo!()
}

fn backward<S: crate::prelude::Shape>(
&self,
op: super::ScalarAddKernelOp<E>,
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_inp: &mut Self::Vec,
out: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), crate::prelude::Error> {
todo!()
}
}
webgpu_unary!(Scalar<f32>, f32, WGSL, "scalar_fwd_f32", "scalar_bwd_f32");

impl<E: Dtype> BinaryKernel<super::BinaryAddKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_DATA: bool = true;
35 changes: 9 additions & 26 deletions dfdx-core/src/tensor_ops/clamp/webgpu_kernel.rs
@@ -1,28 +1,11 @@
use std::borrow::Cow;
use crate::prelude::webgpu_kernels::webgpu_unary;

use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};
const WGSL: &str = "TODO";

impl<E: Dtype> UnaryKernel<super::ClampKernelOp<E>, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = false;

const BACKWARD_WITHOUT_DATA: bool = false;

fn forward<S: crate::prelude::Shape>(
&self,
op: super::ClampKernelOp<E>,
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
todo!()
}

fn backward<S: crate::prelude::Shape>(
&self,
op: super::ClampKernelOp<E>,
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_inp: &mut Self::Vec,
out: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), crate::prelude::Error> {
todo!()
}
}
webgpu_unary!(
super::ClampKernelOp<f32>,
f32,
WGSL,
"clamp_fwd_f32",
"clamp_bwd_f32"
);
29 changes: 3 additions & 26 deletions dfdx-core/src/tensor_ops/cos/webgpu_kernel.rs
@@ -1,28 +1,5 @@
use std::borrow::Cow;
use crate::prelude::webgpu_kernels::webgpu_unary;

use crate::prelude::{ops::UnaryKernel, Dtype, Webgpu};
const WGSL: &str = "TODO";

impl<E: Dtype> UnaryKernel<super::CosKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = false;

const BACKWARD_WITHOUT_DATA: bool = false;

fn forward<S: crate::prelude::Shape>(
&self,
op: super::CosKernelOp,
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
todo!()
}

fn backward<S: crate::prelude::Shape>(
&self,
op: super::CosKernelOp,
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_inp: &mut Self::Vec,
out: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), crate::prelude::Error> {
todo!()
}
}
webgpu_unary!(super::CosKernelOp, f32, WGSL, "cos_fwd_f32", "cos_bwd_f32");
31 changes: 4 additions & 27 deletions dfdx-core/src/tensor_ops/div/webgpu_kernel.rs
@@ -1,34 +1,11 @@
use super::{BinaryDivKernelOp as Binary, ScalarDivKernelOp as Scalar};
use std::borrow::Cow;

use crate::prelude::{
ops::{BinaryKernel, UnaryKernel},
Dtype, Webgpu,
};
use crate::prelude::{ops::BinaryKernel, webgpu_kernels::webgpu_unary, Dtype, Webgpu};

impl<E: Dtype> UnaryKernel<super::ScalarDivKernelOp<E>, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = false;
const WGSL: &str = "TODO";

const BACKWARD_WITHOUT_DATA: bool = true;

fn forward<S: crate::prelude::Shape>(
&self,
op: super::ScalarDivKernelOp<E>,
inp: Cow<crate::prelude::Tensor<S, E, Self>>,
) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
todo!()
}

fn backward<S: crate::prelude::Shape>(
&self,
op: super::ScalarDivKernelOp<E>,
inp: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_inp: &mut Self::Vec,
out: &impl crate::prelude::Tensorlike<S, E, Self>,
grad_out: &Self::Vec,
) -> Result<(), crate::prelude::Error> {
todo!()
}
}
webgpu_unary!(const_df() Scalar<f32>, f32, WGSL, "scalar_div_fwd", "scalar_div_bwd");

impl<E: Dtype> BinaryKernel<super::BinaryDivKernelOp, E> for Webgpu {
const BACKWARD_WITHOUT_DATA: bool = true;