Skip to content

Commit

Permalink
Partial implementation of Device<E> for Webgpu
Browse files Browse the repository at this point in the history
  • Loading branch information
favilo committed Nov 27, 2023
1 parent 30c7e48 commit 20925cf
Show file tree
Hide file tree
Showing 42 changed files with 767 additions and 30 deletions.
3 changes: 1 addition & 2 deletions dfdx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", optional = true }
futures-lite = { version = "2.0.1", optional = true }
bytemuck = { version = "1.14.0", optional = true }

[dev-dependencies]
tempfile = "3.3.0"
Expand All @@ -62,7 +61,7 @@ fast-alloc = ["std"]

cuda = ["dep:cudarc", "dep:glob"]
cudnn = ["cuda", "cudarc?/cudnn"]
webgpu = ["dep:wgpu", "dep:futures-lite", "dep:bytemuck", "wgpu/expose-ids"]
webgpu = ["dep:wgpu", "dep:futures-lite", "wgpu/expose-ids"]

f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"]

Expand Down
26 changes: 13 additions & 13 deletions dfdx-core/src/tensor/webgpu/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub(crate) fn round_to_buffer_alignment(size: u64) -> u64 {
}

impl Webgpu {
fn tensor_from_host_buf<S: Shape, E: Unit + bytemuck::Pod>(
fn tensor_from_host_buf<S: Shape, E: Unit>(
&self,
shape: S,
buf: Vec<E>,
Expand All @@ -28,7 +28,7 @@ impl Webgpu {
Ok(self.build_tensor(shape, shape.strides(), buffer))
}

pub(crate) fn build_tensor<S: Shape, E: Unit + bytemuck::Pod>(
pub(crate) fn build_tensor<S: Shape, E: Unit>(
&self,
shape: S,
strides: S::Concrete,
Expand All @@ -52,42 +52,42 @@ impl Webgpu {
}
}

impl<E: Unit + SafeZeros + From<f32> + bytemuck::Pod> ZerosTensor<E> for Webgpu {
impl<E: Unit + SafeZeros + From<bool>> ZerosTensor<E> for Webgpu {
fn try_zeros_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let strides = shape.strides();
let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
data.copy_to_device(
&self.dev,
&self.queue,
&vec![E::from(0.0); shape.num_elements()],
&vec![E::from(false); shape.num_elements()],
);

Ok(self.build_tensor(shape, strides, data))
}
}

impl<E: Unit + SafeZeros + From<f32> + bytemuck::Pod> ZeroFillStorage<E> for Webgpu {
impl<E: Unit + SafeZeros + From<bool>> ZeroFillStorage<E> for Webgpu {
fn try_fill_with_zeros(&self, storage: &mut Self::Vec) -> Result<(), Error> {
storage.copy_to_device(
&self.dev,
&self.queue,
&vec![E::from(0.0); storage.size() as usize / std::mem::size_of::<E>()],
&vec![E::from(false); storage.size() as usize / std::mem::size_of::<E>()],
);

Ok(())
}
}

impl<E: Unit + bytemuck::Pod> OnesTensor<E> for Webgpu {
impl<E: Unit> OnesTensor<E> for Webgpu {
    /// Builds a tensor with the same shape as `src` in which every element is
    /// `E::ONE`.
    ///
    /// The values are materialised in a host-side vector first and then handed
    /// to `tensor_from_host_buf`, whose result (including any error) is
    /// returned unchanged.
    fn try_ones_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
        let shape = *src.shape();
        self.tensor_from_host_buf(shape, std::vec![E::ONE; shape.num_elements()])
    }
}

impl<E: Unit + bytemuck::Pod> TriangleTensor<E> for Webgpu
impl<E: Unit> TriangleTensor<E> for Webgpu
where
Cpu: TriangleTensor<E>,
{
Expand Down Expand Up @@ -118,7 +118,7 @@ where
}
}

impl<E: Unit + bytemuck::Pod> OneFillStorage<E> for Webgpu {
impl<E: Unit> OneFillStorage<E> for Webgpu {
fn try_fill_with_ones(&self, storage: &mut Self::Vec) -> Result<(), Error> {
let len = storage.size() as usize / std::mem::size_of::<E>();
let buf = std::vec![E::ONE; len];
Expand All @@ -130,7 +130,7 @@ impl<E: Unit + bytemuck::Pod> OneFillStorage<E> for Webgpu {
}
}

impl<E: Unit + bytemuck::Pod> SampleTensor<E> for Webgpu
impl<E: Unit> SampleTensor<E> for Webgpu
where
Cpu: SampleTensor<E>,
{
Expand Down Expand Up @@ -176,7 +176,7 @@ where
}
}

impl<E: Unit + bytemuck::Pod> CopySlice<E> for Webgpu {
impl<E: Unit> CopySlice<E> for Webgpu {
fn copy_from<S: Shape, T>(dst: &mut Tensor<S, E, Self, T>, src: &[E]) {
assert_eq!(
dst.data.size() as usize,
Expand All @@ -200,7 +200,7 @@ impl<E: Unit + bytemuck::Pod> CopySlice<E> for Webgpu {
}
}

impl<E: Unit + bytemuck::Pod> TensorFromVec<E> for Webgpu {
impl<E: Unit> TensorFromVec<E> for Webgpu {
fn try_tensor_from_vec<S: Shape>(
&self,
src: Vec<E>,
Expand All @@ -216,7 +216,7 @@ impl<E: Unit + bytemuck::Pod> TensorFromVec<E> for Webgpu {
}
}

impl<S: Shape, E: Unit + bytemuck::Pod> TensorToArray<S, E> for Webgpu
impl<S: Shape, E: Unit> TensorToArray<S, E> for Webgpu
where
Cpu: TensorToArray<S, E> + Storage<E>,
{
Expand Down
25 changes: 10 additions & 15 deletions dfdx-core/src/tensor/webgpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,14 @@ impl Buffer {
self.data.size() as usize
}

pub(crate) fn copy_to_device<E: Unit + bytemuck::Pod>(
&self,
dev: &Device,
queue: &Queue,
slice: &[E],
) {
let slice = bytemuck::cast_slice(slice);
/// Uploads `slice` into this GPU buffer and blocks until the queue has been
/// processed.
///
/// The number of host bytes read is derived from `slice` itself, never from
/// the buffer size: the device buffer can be larger than the host data, and
/// reading `self.size()` bytes from `slice` would then run past the end of the
/// host allocation. When the buffer is larger, the tail is zero-filled so that
/// exactly `self.size()` bytes are still written, as before.
///
/// If `slice` holds more bytes than the buffer, only the first `self.size()`
/// bytes are uploaded (a debug assertion flags this in debug builds).
pub(crate) fn copy_to_device<E: Unit>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
    let capacity = self.size();
    let host_len = std::mem::size_of_val(slice);
    debug_assert!(
        host_len <= capacity,
        "host slice ({host_len} bytes) is larger than the device buffer ({capacity} bytes)"
    );
    let copy_len = host_len.min(capacity);

    // SAFETY: `slice` is a valid, initialised `&[E]` spanning `host_len` bytes
    // and `copy_len <= host_len`, so the byte view stays inside the borrowed
    // allocation. `u8` has alignment 1, so the pointer cast cannot misalign.
    // NOTE(review): assumes `Unit` types contain no padding bytes — confirm.
    let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, copy_len) };

    if copy_len == capacity {
        queue.write_buffer(&self.data, 0, bytes);
    } else {
        // Pad on the host so the write still covers the whole buffer without
        // touching memory beyond `slice`.
        let mut padded = std::vec![0u8; capacity];
        padded[..copy_len].copy_from_slice(bytes);
        queue.write_buffer(&self.data, 0, &padded);
    }

    queue.submit(std::iter::empty());
    dev.poll(Maintain::Wait);
}

pub(crate) fn copy_to_host<E: Unit + bytemuck::Pod>(
&self,
dev: &Device,
queue: &Queue,
buf: &mut [E],
) {
pub(crate) fn copy_to_host<E: Unit>(&self, dev: &Device, queue: &Queue, buf: &mut [E]) {
let (sender, receiver) = std::sync::mpsc::channel();
let buffer = dev.create_buffer(&BufferDescriptor {
label: None,
Expand All @@ -86,7 +76,12 @@ impl Buffer {
let _ = receiver.recv().unwrap();
let data = slice.get_mapped_range();
// TODO: How are we sure this is safe?
let slice = bytemuck::cast_slice(&*data);
let slice = unsafe {
std::slice::from_raw_parts(
data.as_ptr() as *const E,
self.size() / std::mem::size_of::<E>(),
)
};
buf.copy_from_slice(slice);
}
}
Expand Down Expand Up @@ -299,7 +294,7 @@ impl Synchronize for Webgpu {
}
}

impl<E: Unit + bytemuck::Pod> Storage<E> for Webgpu {
impl<E: Unit> Storage<E> for Webgpu {
type Vec = CachableBuffer<E>;

fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Error> {
Expand Down
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor_ops/adam/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

#[cfg(feature = "webgpu")]
mod webgpu_kernel;

use crate::{
shapes::{Dtype, Shape},
tensor::{Error, Storage, Tensor},
Expand Down
15 changes: 15 additions & 0 deletions dfdx-core/src/tensor_ops/adam/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use crate::prelude::{Dtype, Webgpu};

/// Placeholder `AdamKernel` implementation for the [`Webgpu`] device.
///
/// It exists so that `Webgpu` satisfies the trait bound; no kernel has been
/// written yet.
impl<E: Dtype> super::AdamKernel<E> for Webgpu {
    /// Not implemented: always panics through `todo!`.
    ///
    /// The arguments are underscore-prefixed because the stub does not read
    /// them; this keeps the build free of `unused_variables` warnings until a
    /// real kernel is added.
    fn adam_kernel(
        &self,
        _t: i32,
        _cfg: &crate::prelude::AdamConfig,
        _param: &mut Self::Vec,
        _moment1: &mut Self::Vec,
        _moment2: &mut Self::Vec,
        _grad: &Self::Vec,
    ) -> Result<(), crate::prelude::Error> {
        todo!("AdamKernel is not implemented for the Webgpu device yet")
    }
}
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor_ops/add/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

#[cfg(feature = "webgpu")]
mod webgpu_kernel;

use super::ops::*;
use crate::{
shapes::*,
Expand Down
57 changes: 57 additions & 0 deletions dfdx-core/src/tensor_ops/add/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
extern crate alloc;
use alloc::borrow::Cow;

use crate::prelude::{
ops::{BinaryKernel, UnaryKernel},
Dtype, Webgpu,
};

/// Placeholder scalar-add kernel for the [`Webgpu`] device.
///
/// Only the associated constants carry information at the moment; both
/// `forward` and `backward` are unimplemented stubs.
impl<E: Dtype> UnaryKernel<super::ScalarAddKernelOp<E>, E> for Webgpu {
    const BACKWARD_WITHOUT_INP: bool = false;

    const BACKWARD_WITHOUT_DATA: bool = true;

    /// Not implemented: always panics through `todo!`.
    ///
    /// Arguments are underscore-prefixed because the stub does not read them,
    /// which avoids `unused_variables` warnings.
    fn forward<S: crate::prelude::Shape>(
        &self,
        _op: super::ScalarAddKernelOp<E>,
        _inp: Cow<crate::prelude::Tensor<S, E, Self>>,
    ) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
        todo!("ScalarAddKernelOp forward is not implemented for the Webgpu device yet")
    }

    /// Not implemented: always panics through `todo!`.
    fn backward<S: crate::prelude::Shape>(
        &self,
        _op: super::ScalarAddKernelOp<E>,
        _inp: &impl crate::prelude::Tensorlike<S, E, Self>,
        _grad_inp: &mut Self::Vec,
        _out: &impl crate::prelude::Tensorlike<S, E, Self>,
        _grad_out: &Self::Vec,
    ) -> Result<(), crate::prelude::Error> {
        todo!("ScalarAddKernelOp backward is not implemented for the Webgpu device yet")
    }
}

/// Placeholder binary-add kernel for the [`Webgpu`] device.
///
/// Only the associated constant carries information at the moment; both
/// `forward` and `backward` are unimplemented stubs.
impl<E: Dtype> BinaryKernel<super::BinaryAddKernelOp, E> for Webgpu {
    const BACKWARD_WITHOUT_DATA: bool = true;

    /// Not implemented: always panics through `todo!`.
    ///
    /// Arguments are underscore-prefixed because the stub does not read them,
    /// which avoids `unused_variables` warnings.
    fn forward<S: crate::prelude::Shape>(
        &self,
        _op: super::BinaryAddKernelOp,
        _lhs: Cow<crate::prelude::Tensor<S, E, Self>>,
        _rhs: Cow<crate::prelude::Tensor<S, E, Self>>,
    ) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
        todo!("BinaryAddKernelOp forward is not implemented for the Webgpu device yet")
    }

    /// Not implemented: always panics through `todo!`.
    fn backward<S: crate::prelude::Shape>(
        &self,
        _op: super::BinaryAddKernelOp,
        _lhs: &impl crate::prelude::Tensorlike<S, E, Self>,
        _grad_lhs: &mut Self::Vec,
        _rhs: &impl crate::prelude::Tensorlike<S, E, Self>,
        _grad_rhs: &mut Self::Vec,
        _grad_out: &Self::Vec,
    ) -> Result<(), crate::prelude::Error> {
        todo!("BinaryAddKernelOp backward is not implemented for the Webgpu device yet")
    }
}
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor_ops/choose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

#[cfg(feature = "webgpu")]
mod webgpu_kernel;

use crate::{
shapes::{Dtype, HasShape, Shape},
tensor::{Error, Merge, PutTape, SplitTape, Storage, Tape, Tensor},
Expand Down
24 changes: 24 additions & 0 deletions dfdx-core/src/tensor_ops/choose/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use crate::prelude::{Dtype, Webgpu};

/// Placeholder `ChooseKernel` implementation for the [`Webgpu`] device.
///
/// Both methods are unimplemented stubs that exist so `Webgpu` satisfies the
/// trait bound.
impl<E: Dtype> super::ChooseKernel<E> for Webgpu {
    /// Not implemented: always panics through `todo!`.
    ///
    /// Arguments are underscore-prefixed because the stub does not read them,
    /// which avoids `unused_variables` warnings.
    fn forward<S: crate::prelude::Shape>(
        &self,
        _cond: &crate::prelude::Tensor<S, bool, Self>,
        _lhs: &crate::prelude::Tensor<S, E, Self>,
        _rhs: &crate::prelude::Tensor<S, E, Self>,
    ) -> Result<crate::prelude::Tensor<S, E, Self>, crate::prelude::Error> {
        todo!("ChooseKernel forward is not implemented for the Webgpu device yet")
    }

    /// Not implemented: always panics through `todo!`.
    fn backward<S: crate::prelude::Shape>(
        &self,
        _cond: &crate::prelude::Tensor<S, bool, Self>,
        _lhs: &crate::prelude::Tensor<S, E, Self>,
        _grad_lhs: &mut <Self as crate::prelude::Storage<E>>::Vec,
        _rhs: &crate::prelude::Tensor<S, E, Self>,
        _grad_rhs: &mut <Self as crate::prelude::Storage<E>>::Vec,
        _grad_out: &<Self as crate::prelude::Storage<E>>::Vec,
    ) -> Result<(), crate::prelude::Error> {
        todo!("ChooseKernel backward is not implemented for the Webgpu device yet")
    }
}
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor_ops/concat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

#[cfg(feature = "webgpu")]
mod webgpu_kernel;

/// Concatenate two tensors along the first dimension.
///
/// **Pytorch equivalent** `torch.concat`.
Expand Down
25 changes: 25 additions & 0 deletions dfdx-core/src/tensor_ops/concat/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::{shapes::*, tensor::*};

use super::ConcatShape;

/// Placeholder `ConcatKernel` implementation for the [`Webgpu`] device.
///
/// Both methods are unimplemented stubs that exist so `Webgpu` satisfies the
/// trait bound.
impl<E: Dtype> super::ConcatKernel<E> for Webgpu {
    /// Not implemented: always panics through `todo!`.
    ///
    /// Arguments are underscore-prefixed because the stub does not read them,
    /// which avoids `unused_variables` warnings.
    fn forward<A: Shape, B: Shape>(
        &self,
        _a: &Tensor<A, E, Self>,
        _b: &Tensor<B, E, Self>,
    ) -> Result<Tensor<A::Catted, E, Self>, Error>
    where
        A: ConcatShape<B>,
    {
        todo!("ConcatKernel forward is not implemented for the Webgpu device yet")
    }

    /// Not implemented: always panics through `todo!`.
    fn backward(
        &self,
        _grad_a: &mut Self::Vec,
        _grad_b: &mut Self::Vec,
        _grad_out: &Self::Vec,
    ) -> Result<(), Error> {
        todo!("ConcatKernel backward is not implemented for the Webgpu device yet")
    }
}
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor_ops/concat_along/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::{shapes::*, tensor::*};
mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;
#[cfg(feature = "webgpu")]
mod webgpu_kernel;

/// Concatenate two tensors along a given axis.
///
Expand Down
25 changes: 25 additions & 0 deletions dfdx-core/src/tensor_ops/concat_along/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::{shapes::*, tensor::*};

/// Placeholder `ConcatAlongKernel` implementation for the [`Webgpu`] device.
///
/// Both methods are unimplemented stubs that exist so `Webgpu` satisfies the
/// trait bound.
impl<E: Dtype> super::ConcatAlongKernel<E> for Webgpu {
    /// Not implemented: always panics through `todo!`.
    ///
    /// Arguments are underscore-prefixed because the stub does not read them,
    /// which avoids `unused_variables` warnings.
    fn forward<A: Shape, B: Shape, C: Shape>(
        &self,
        _ax: usize,
        _a: &Tensor<A, E, Self>,
        _b: &Tensor<B, E, Self>,
        _c: &mut Tensor<C, E, Self>,
    ) -> Result<(), Error> {
        todo!("ConcatAlongKernel forward is not implemented for the Webgpu device yet")
    }

    /// Not implemented: always panics through `todo!`.
    fn backward<A: Shape, B: Shape>(
        &self,
        _ax: usize,
        _a: &GhostTensor<A, E, Self>,
        _grad_a: &mut Self::Vec,
        _b: &GhostTensor<B, E, Self>,
        _grad_b: &mut Self::Vec,
        _grad_out: &Self::Vec,
    ) -> Result<(), Error> {
        todo!("ConcatAlongKernel backward is not implemented for the Webgpu device yet")
    }
}
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor_ops/div/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;

#[cfg(feature = "webgpu")]
mod webgpu_kernel;

use super::ops::*;
use crate::{shapes::*, tensor::*};

Expand Down
Loading

0 comments on commit 20925cf

Please sign in to comment.