Skip to content

Commit

Permalink
avoid conv1d bound for cudnn
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Feb 7, 2024
1 parent b0cef58 commit 9d34c62
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions dfdx-core/src/tensor_ops/utilities/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,33 +120,60 @@ pub trait Device<E: Dtype>:
+ crate::tensor_ops::axpy::AxpyKernel<E>

// conv1d
+ super::super::conv1d::Conv1DKernel<E>
+ NonCudnnCuda<E>
{
}

#[cfg(feature = "cudnn")]
pub trait NonCudnnCuda<E: Dtype> {}

#[cfg(not(feature = "cudnn"))]
pub trait NonCudnnCuda<E: Dtype>:
// conv1d
super::super::conv1d::Conv1DKernel<E>
{
}

#[cfg(feature = "f16")]
impl Device<f16> for crate::tensor::Cpu {}
#[cfg(feature = "f16")]
impl Device<AMP<f16>> for crate::tensor::Cpu {}
mod f16 {
use super::*;
impl Device<f16> for crate::tensor::Cpu {}
impl NonCudnnCuda<f16> for crate::tensor::Cpu {}
impl Device<AMP<f16>> for crate::tensor::Cpu {}
}
impl Device<f32> for crate::tensor::Cpu {}
impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
impl Device<f64> for crate::tensor::Cpu {}
impl NonCudnnCuda<f64> for crate::tensor::Cpu {}

#[cfg(all(feature = "cuda", feature = "f16"))]
impl Device<f16> for crate::tensor::Cuda {}
#[cfg(all(feature = "cuda", feature = "f16"))]
impl Device<AMP<f16>> for crate::tensor::Cuda {}
#[cfg(feature = "cuda")]
impl Device<f32> for crate::tensor::Cuda {}
mod cuda_f16 {
use super::*;
impl Device<f16> for crate::tensor::Cuda {}
impl NonCudnnCuda<f16> for crate::tensor::Cuda {}
impl Device<AMP<f16>> for crate::tensor::Cuda {}
impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cuda {}
}
#[cfg(feature = "cuda")]
impl Device<f64> for crate::tensor::Cuda {}
mod cuda {
use super::*;
impl Device<f32> for crate::tensor::Cuda {}
impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
impl Device<f64> for crate::tensor::Cuda {}
impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
}

// TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
// #[cfg(all(feature = "webgpu", feature = "f16"))]
// impl Device<f16> for crate::tensor::Webgpu {}
// #[cfg(all(feature = "webgpu", feature = "f16"))]
// impl Device<AMP<f16>> for crate::tensor::Webgpu {}
#[cfg(feature = "webgpu")]
impl Device<f32> for crate::tensor::Webgpu {}
mod webgpu {
use super::*;
impl Device<f32> for crate::tensor::Webgpu {}
impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
}

// TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
// #[cfg(feature = "webgpu")]
Expand Down

0 comments on commit 9d34c62

Please sign in to comment.