Skip to content

Commit

Permalink
fix: add WebGPUNativeType
Browse files Browse the repository at this point in the history
  • Loading branch information
DonIsaac committed Jan 7, 2024
1 parent 9fe3681 commit 1bff6c9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 16 deletions.
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor/webgpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod allocate;
mod device;
mod types;

pub use device::Buffer;
pub use device::Webgpu;
pub use types::*;

#[cfg(test)]
mod tests {
Expand Down
56 changes: 56 additions & 0 deletions dfdx-core/src/tensor/webgpu/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use crate::shapes::Unit;

/// A primitive data type natively supported by WebGPU.
///
/// See: https://www.w3.org/TR/WGSL/#types
///
/// todo: support packed types
pub trait WebgpuNativeType : Unit {
/// Name of the data type in WGSL.
const NAME: &'static str;
}

macro_rules! webgpu_type {
($RustTy:ty) => {
impl WebgpuNativeType for $RustTy {
const NAME: &'static str = stringify!($RustTy);
}
};
($RustTy:ty, $WgpuTy:expr) => {
impl WebgpuNativeType for $RustTy {
const NAME: &'static str = $WgpuTy;
}
};
}

/*
see:
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F16
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F64
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_I16
*/
#[cfg(feature = "f16")]
webgpu_type!(half::f16, "f16");
webgpu_type!(f32);
// todo: only enable when f64 feature is enabled
#[cfg(feature = "f64")]
webgpu_type!(f64);

#[cfg(feature = "i16")]
webgpu_type!(i16);
webgpu_type!(i32);

webgpu_type!(u32);
webgpu_type!(bool);

pub(crate) trait HasGlslType {
const TYPE: &'static str;
}

impl HasGlslType for f32 {
const TYPE: &'static str = "float";
}

impl HasGlslType for f64 {
const TYPE: &'static str = "double";
}
22 changes: 6 additions & 16 deletions dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ use crate::{
use core::any::TypeId;
use std::{borrow::Cow, marker::PhantomData, sync::Arc, vec::Vec};

use wgpu::{
BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
};


/// Creates a [`BindGroup`] for a pipeline from a set of [`wgpu::BindingResource`]s.
macro_rules! webgpu_params {
($self:expr, $pipeline:expr; $($x:expr),+ $(,)? ) => {
Expand Down Expand Up @@ -72,6 +77,7 @@ macro_rules! webgpu_unary {
}
};
}
pub(crate) use webgpu_unary;

/// Zero-sized marker type for forward pass TypeId
#[derive(Debug, Default)]
Expand All @@ -85,22 +91,6 @@ pub(crate) struct Backward<E: Dtype, K> {
_phantom: PhantomData<(E, K)>,
}

pub(crate) trait HasGlslType {
const TYPE: &'static str;
}

impl HasGlslType for f32 {
const TYPE: &'static str = "float";
}

impl HasGlslType for f64 {
const TYPE: &'static str = "double";
}

pub(crate) use webgpu_unary;
use wgpu::{
BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
};

impl<E: Dtype + HasGlslType, K: UnaryOpWebgpuKernel<E> + 'static> UnaryKernel<K, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;
Expand Down

0 comments on commit 1bff6c9

Please sign in to comment.