Skip to content

Commit

Permalink
Merge branch 'main' into don/feat/wgpu_to_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jan 25, 2024
2 parents a1fcda8 + e04dd4f commit 16658b2
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions dfdx-core/src/tensor_ops/conv1d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod tests;

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(super) struct Conv1DOp {
pub struct Conv1DOp {
pub kernel: usize,
pub stride: usize,
pub padding: usize,
Expand All @@ -22,7 +22,7 @@ pub(super) struct Conv1DOp {
pub l_out: usize,
}

pub(super) trait Conv1DKernel<E: Dtype>: Storage<E> {
pub trait Conv1DKernel<E: Dtype>: Storage<E> {
fn alloc<S: Shape>(&self, s: S) -> Result<Tensor<S, E, Self>, Error>;

fn forward<L: Shape, R: Shape, O: Shape>(
Expand Down Expand Up @@ -108,6 +108,7 @@ pub trait TryConv1D<Stride, Padding, Dilation, Groups>: Sized {
) -> Result<Self::Convolved, Error>;
}

#[cfg(feature = "nightly")]
impl<
const KERNEL: usize,
const STRIDE: usize,
Expand Down
2 changes: 0 additions & 2 deletions dfdx-core/src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,7 @@ pub use upscale2d::{
};
pub use var_to::VarTo;

#[cfg(feature = "nightly")]
mod conv1d;
#[cfg(feature = "nightly")]
pub use conv1d::TryConv1D;

#[cfg(feature = "nightly")]
Expand Down
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor_ops/utilities/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ pub trait Device<E: Dtype>:
+ BinaryKernel<super::super::maximum::MaximumKernelOp, E>
+ BinaryKernel<super::super::minimum::MinimumKernelOp, E>
+ crate::tensor_ops::axpy::AxpyKernel<E>

// conv1d
+ super::super::conv1d::Conv1DKernel<E>
{
}

Expand Down
2 changes: 1 addition & 1 deletion dfdx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@
//! | Optimizer | dfdx | pytorch |
//! | --- | --- | --- |
//! | SGD | [nn::optim::Sgd] | `torch.optim.SGD` |
//! | Adam | [nn::optim::Adam] | torch.optim.Adam` |
//! | Adam | [nn::optim::Adam] | `torch.optim.Adam` |
//! | AdamW | [nn::optim::Adam] with [nn::optim::WeightDecay::Decoupled] | `torch.optim.AdamW` |
//! | RMSprop | [nn::optim::RMSprop] | `torch.optim.RMSprop` |
//!
Expand Down

0 comments on commit 16658b2

Please sign in to comment.