Skip to content

Commit

Permalink
Fixing no-std stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
favilo committed Nov 28, 2023
1 parent 3b25249 commit 87cb8f3
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
3 changes: 2 additions & 1 deletion dfdx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", optional = true }
futures-lite = { version = "2.0.1", optional = true }
thingbuf = { version = "0.1.4", optional = true }

[dev-dependencies]
tempfile = "3.3.0"
Expand All @@ -61,7 +62,7 @@ fast-alloc = ["std"]

cuda = ["dep:cudarc", "dep:glob"]
cudnn = ["cuda", "cudarc?/cudnn"]
webgpu = ["dep:wgpu", "dep:futures-lite", "wgpu/expose-ids"]
webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"]

f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"]

Expand Down
8 changes: 4 additions & 4 deletions dfdx-core/src/tensor/webgpu/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl<E: Unit + SafeZeros> ZeroFillStorage<E> for Webgpu {
impl<E: Unit> OnesTensor<E> for Webgpu {
fn try_ones_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let buf = std::vec![E::ONE; shape.num_elements()];
let buf = vec![E::ONE; shape.num_elements()];
self.tensor_from_host_buf(shape, buf)
}
}
Expand All @@ -90,7 +90,7 @@ where
diagonal: impl Into<Option<isize>>,
) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let mut data = std::vec![val; shape.num_elements()];
let mut data = vec![val; shape.num_elements()];
let offset = diagonal.into().unwrap_or(0);
triangle_mask(&mut data, &shape, true, offset);
self.tensor_from_host_buf(shape, data)
Expand All @@ -103,7 +103,7 @@ where
diagonal: impl Into<Option<isize>>,
) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let mut data = std::vec![val; shape.num_elements()];
let mut data = vec![val; shape.num_elements()];
let offset = diagonal.into().unwrap_or(0);
triangle_mask(&mut data, &shape, false, offset);
self.tensor_from_host_buf(shape, data)
Expand All @@ -113,7 +113,7 @@ where
impl<E: Unit> OneFillStorage<E> for Webgpu {
fn try_fill_with_ones(&self, storage: &mut Self::Vec) -> Result<(), Error> {
let len = storage.size() as usize / std::mem::size_of::<E>();
let buf = std::vec![E::ONE; len];
let buf = vec![E::ONE; len];
storage
.data
.copy_to_device::<E>(&self.dev, &self.queue, &buf);
Expand Down
22 changes: 16 additions & 6 deletions dfdx-core/src/tensor/webgpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ use crate::{
},
};

#[cfg(feature = "no-std")]
use spin::Mutex;

#[cfg(not(feature = "no-std"))]
use std::sync::Mutex;

use std::{marker::PhantomData, sync::Arc, vec::Vec};

use super::allocate::round_to_buffer_alignment;
Expand Down Expand Up @@ -52,7 +58,7 @@ impl Buffer {
}

pub(crate) fn copy_to_host<E: Unit>(&self, dev: &Device, queue: &Queue, buf: &mut [E]) {
let (sender, receiver) = std::sync::mpsc::channel();
let (sender, receiver) = thingbuf::mpsc::channel(1);
let buffer = dev.create_buffer(&BufferDescriptor {
label: None,
size: self.size() as u64,
Expand All @@ -66,11 +72,11 @@ impl Buffer {
}
let slice = buffer.slice(..self.size() as u64);
slice.map_async(wgpu::MapMode::Read, move |_| {
sender.send(()).unwrap();
futures_lite::future::block_on(sender.send(())).unwrap();
});
dev.poll(Maintain::Wait);

let _ = receiver.recv().unwrap();
let _ = futures_lite::future::block_on(receiver.recv());
let data = slice.get_mapped_range();
// TODO: How are we sure this is safe?
let slice = unsafe {
Expand Down Expand Up @@ -110,15 +116,19 @@ impl Default for Webgpu {
}
}

static CONSTRUCTOR_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
static CONSTRUCTOR_MUTEX: Mutex<()> = Mutex::new(());

impl Webgpu {
pub fn seed_from_u64(seed: u64) -> Self {
Self::try_build(seed).unwrap()
}

pub fn try_build(seed: u64) -> Result<Self, Error> {
let _lock = CONSTRUCTOR_MUTEX.lock().unwrap();
#[cfg(feature = "no-std")]
let _lock = { CONSTRUCTOR_MUTEX.lock() };
#[cfg(not(feature = "no-std"))]
let _lock = { CONSTRUCTOR_MUTEX.lock().unwrap() };

let cpu = Cpu::seed_from_u64(seed);
let instance = Arc::new(Instance::new(InstanceDescriptor::default()));
let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default()))
Expand Down Expand Up @@ -332,7 +342,7 @@ impl<E: Unit> Storage<E> for Webgpu {
device: self.cpu.clone(),
tape: NoneTape,
};
let buf = std::sync::Arc::get_mut(&mut cpu_tensor.data).unwrap();
let buf = Arc::get_mut(&mut cpu_tensor.data).unwrap();
tensor.data.copy_to_host::<E>(&self.dev, &self.queue, buf);
self.cpu.tensor_to_vec::<S, _>(&cpu_tensor)
}
Expand Down
1 change: 1 addition & 0 deletions dfdx-core/src/tensor_ops/stack/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{shapes::*, tensor::Webgpu};
use std::vec::Vec;

impl<E: Dtype> super::StackKernel<E> for Webgpu {
fn forward<S: Shape, Num: Dim>(
Expand Down

0 comments on commit 87cb8f3

Please sign in to comment.