mv and add op layers
- Added serialization for some of the data in these layers:
  - upscale2d
  - reshape
  - pool_2d_avg
  - pool_2d_max
  - pool_2d_min
  - leaky_relu
  - dropout
- Added layers for more tensor ops:
  - add
  - bce
  - boolean
  - broadcast_to
  - choose
  - clamp
  - cmp
  - div
  - huber_error
  - logsumexp_to
  - max_to
  - maximum
  - min_to
  - minimum
  - mean_to
  - mul
  - nans_to
  - negate
  - normalize
  - permute_to
  - pow
  - realize_to
  - stddev_to
  - var_to
swfsql committed Mar 3, 2024
1 parent cf811fb commit 973a2f1
Showing 35 changed files with 847 additions and 58 deletions.
1 change: 1 addition & 0 deletions dfdx/examples/advanced-resnet18.rs
@@ -1,4 +1,5 @@
#![cfg_attr(feature = "nightly", feature(generic_const_exprs))]
#![allow(incomplete_features)]

#[cfg(not(feature = "nightly"))]
fn main() {
2 changes: 0 additions & 2 deletions dfdx/src/nn/layers/mod.rs
@@ -22,7 +22,6 @@ pub mod ops;
mod pool_global_avg;
mod pool_global_max;
mod pool_global_min;
-mod reshape;
mod residual_add;
mod residual_mul;
mod split_into;
@@ -52,7 +51,6 @@ pub use multi_head_attention::{MultiHeadAttention, MultiHeadAttentionConfig};
pub use pool_global_avg::AvgPoolGlobal;
pub use pool_global_max::MaxPoolGlobal;
pub use pool_global_min::MinPoolGlobal;
-pub use reshape::Reshape;
pub use residual_add::ResidualAdd;
pub use residual_mul::ResidualMul;
pub use split_into::SplitInto;
15 changes: 15 additions & 0 deletions dfdx/src/nn/layers/ops/add.rs
@@ -0,0 +1,15 @@
use crate::prelude::*;

/// Calls on [crate::tensor_ops::TryAdd], which for tensors is [crate::tensor_ops::add()].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Add;
impl<Lhs, Rhs> Module<(Lhs, Rhs)> for Add
where
    Lhs: TryAdd<Rhs>,
{
    type Output = <Lhs as TryAdd<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_add(x.1)
    }
}
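A minimal usage sketch for the new Add layer (not part of the diff): it assumes these op layers are re-exported through dfdx::prelude, and it uses Module::forward, the helper that unwraps try_forward.

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let a = dev.tensor([1.0f32, 2.0, 3.0]);
    let b = dev.tensor([10.0f32, 20.0, 30.0]);
    // Add defers to TryAdd, i.e. element-wise addition of the tuple's parts.
    let sum = Add.forward((a, b));
    assert_eq!(sum.array(), [11.0, 22.0, 33.0]);
}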
22 changes: 22 additions & 0 deletions dfdx/src/nn/layers/ops/bce.rs
@@ -0,0 +1,22 @@
use crate::prelude::*;

/// Calls [crate::tensor_ops::bce_with_logits].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Bce;
type Logits<S, E, D, T> = Tensor<S, E, D, T>;
type Probs<S, E, D, T> = Tensor<S, E, D, T>;

impl<S: Shape, E: Dtype, D: Device<E>, LTape: Tape<E, D>, RTape: Tape<E, D>>
    Module<(Logits<S, E, D, LTape>, Probs<S, E, D, RTape>)> for Bce
where
    LTape: Merge<RTape>,
{
    type Output = Logits<S, E, D, LTape>;

    fn try_forward(
        &self,
        x: (Logits<S, E, D, LTape>, Probs<S, E, D, RTape>),
    ) -> Result<Self::Output, Error> {
        x.0.try_bce_with_logits(x.1)
    }
}
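A sketch of Bce in use, under the same prelude assumption; bce_with_logits is element-wise here, so any reduction to a scalar loss (e.g. mean) is a separate step.

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let logits = dev.tensor([-1.0f32, 0.5, 2.0]);
    let target_probs = dev.tensor([0.0f32, 1.0, 0.5]);
    // Element-wise binary cross entropy with logits.
    let elementwise = Bce.forward((logits, target_probs));
    // Reduce to a scalar loss; the Rank0 annotation drives shape inference.
    let loss: Tensor<Rank0, f32, Cpu> = elementwise.mean();
    println!("bce loss: {:?}", loss.array());
}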
50 changes: 50 additions & 0 deletions dfdx/src/nn/layers/ops/boolean.rs
@@ -0,0 +1,50 @@
use crate::prelude::*;
use std::ops::{BitAnd, BitOr, BitXor, Not as BitNot};

/// Calls on [std::ops::BitAnd], which for booleans is [crate::tensor_ops::bool_and].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct And;

/// Calls on [std::ops::Not], which for booleans is [crate::tensor_ops::bool_not].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Not;

/// Calls on [std::ops::BitOr], which for booleans is [crate::tensor_ops::bool_or].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Or;

/// Calls on [std::ops::BitXor], which for booleans is [crate::tensor_ops::bool_xor].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Xor;

impl<Lhs: BitAnd<Rhs>, Rhs> Module<(Lhs, Rhs)> for And {
    type Output = <Lhs as BitAnd<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        Ok(x.0 & x.1)
    }
}

impl<Input: BitNot> Module<Input> for Not {
    type Output = <Input as BitNot>::Output;

    fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
        Ok(!x)
    }
}

impl<Lhs: BitOr<Rhs>, Rhs> Module<(Lhs, Rhs)> for Or {
    type Output = <Lhs as BitOr<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        Ok(x.0 | x.1)
    }
}

impl<Lhs: BitXor<Rhs>, Rhs> Module<(Lhs, Rhs)> for Xor {
    type Output = <Lhs as BitXor<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        Ok(x.0 ^ x.1)
    }
}
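A quick sketch of the boolean layers (same prelude assumption); these operate on bool tensors, which carry no gradient tape.

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let a = dev.tensor([true, false, true, false]);
    let b = dev.tensor([true, true, false, false]);
    // Each layer forwards to the corresponding bool_* tensor op.
    assert_eq!(And.forward((a.clone(), b.clone())).array(), [true, false, false, false]);
    assert_eq!(Or.forward((a.clone(), b.clone())).array(), [true, true, true, false]);
    assert_eq!(Xor.forward((a.clone(), b)).array(), [false, true, true, false]);
    assert_eq!(Not.forward(a).array(), [false, true, false, true]);
}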
28 changes: 28 additions & 0 deletions dfdx/src/nn/layers/ops/broadcast.rs
@@ -0,0 +1,28 @@
use crate::prelude::*;
use std::fmt::Debug;

/// Calls on [crate::tensor_ops::BroadcastTo].
#[derive(Clone, Copy, Debug, Default, ResetParams, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Broadcast<Dst: Shape, Ax: Debug>(
    #[cfg_attr(feature = "safetensors", serialize)] pub Dst,
    #[cfg_attr(feature = "safetensors", serialize)] pub Ax,
);

impl<S: Shape, Ax: Axes + Debug, E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Broadcast<S, Ax> {
    type Built = Self;
    fn try_build_on_device(&self, _device: &D) -> Result<Self::Built, crate::tensor::Error> {
        Ok(*self)
    }
}

impl<Dst: Shape, Ax: Axes + Debug, Input> Module<Input> for Broadcast<Dst, Ax>
where
    Input: BroadcastTo,
    Dst: ReduceShapeTo<<Input as HasShape>::Shape, Ax>,
{
    type Output = <Input as HasShape>::WithShape<Dst>;
    fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
        x.try_broadcast_like(&self.0)
    }
}
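A sketch of Broadcast as a shape-typed layer (assumed usage): the first field is the target shape, the second names the axes being added, mirroring try_broadcast_like.

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let x = dev.tensor([1.0f32, 2.0, 3.0]); // shape (3,)
    // Broadcast (3,) up to (2, 3), replicating along the new axis 0.
    let layer = Broadcast((Const::<2>, Const::<3>), Axis::<0>);
    let y = layer.forward(x);
    assert_eq!(y.array(), [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);
}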
52 changes: 52 additions & 0 deletions dfdx/src/nn/layers/ops/choose.rs
@@ -0,0 +1,52 @@
use crate::prelude::*;

/// Calls on [crate::tensor_ops::ChooseFrom].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Choose<S: Shape, Dev: Storage<bool>> {
    #[cfg_attr(feature = "safetensors", serialize)]
    pub choose: Tensor<S, bool, Dev>,
}

impl<S: Shape, Elem: Dtype, Dev: Device<Elem>> ::dfdx::nn_traits::ResetParams<Elem, Dev>
    for Choose<S, Dev>
{
    fn try_reset_params(&mut self) -> Result<(), ::dfdx::tensor::Error> {
        Ok(())
    }
}

impl<S: Shape, Elem: Dtype, Dev: Device<Elem>> ::dfdx::nn_traits::UpdateParams<Elem, Dev>
    for Choose<S, Dev>
{
    fn try_update_params<_Model, Optim: ::dfdx::nn_traits::Optimizer<_Model, Elem, Dev>>(
        &mut self,
        _optimizer: &mut Optim,
        _gradients: &::dfdx::tensor::Gradients<Elem, Dev>,
        _missing_tensors: &mut Vec<::dfdx::tensor::UniqueId>,
    ) -> Result<(), ::dfdx::tensor::Error> {
        Ok(())
    }
}

impl<S: Shape, Elem: Dtype, Dev: Device<Elem>> ::dfdx::nn_traits::ZeroGrads<Elem, Dev>
    for Choose<S, Dev>
{
    fn try_zero_grads(
        &self,
        _grads: &mut ::dfdx::prelude::Gradients<Elem, Dev>,
    ) -> Result<(), ::dfdx::tensor::Error> {
        Ok(())
    }
}

impl<S: Shape, Lhs, Rhs, Dev: Storage<bool>> Module<(Lhs, Rhs)> for Choose<S, Dev>
where
    Tensor<S, bool, Dev>: ChooseFrom<Lhs, Rhs>,
{
    type Output = <Tensor<S, bool, Dev> as ChooseFrom<Lhs, Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        self.choose.clone().try_choose(x.0, x.1)
    }
}
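Choose stores a boolean mask and merges two inputs with it; a sketch (same prelude assumption) where true picks from the first input and false from the second:

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let layer = Choose {
        choose: dev.tensor([true, false, true]),
    };
    let a = dev.tensor([1.0f32, 2.0, 3.0]);
    let b = dev.tensor([-1.0f32, -2.0, -3.0]);
    assert_eq!(layer.forward((a, b)).array(), [1.0, -2.0, 3.0]);
}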
30 changes: 30 additions & 0 deletions dfdx/src/nn/layers/ops/clamp.rs
@@ -0,0 +1,30 @@
use crate::prelude::*;
use std::fmt::Debug;

/// Calls [crate::tensor_ops::clamp].
#[derive(Clone, Copy, Debug, ResetParams, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Clamp<E> {
    #[cfg_attr(feature = "safetensors", serialize)]
    pub min: E,
    #[cfg_attr(feature = "safetensors", serialize)]
    pub max: E,
}

impl<E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Clamp<E> {
    type Built = Self;
    fn try_build_on_device(&self, _device: &D) -> Result<Self::Built, crate::tensor::Error> {
        Ok(*self)
    }
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Clamp<E>
where
    E: Into<f64>,
{
    type Output = Tensor<S, E, D, T>;

    fn try_forward(&self, x: Tensor<S, E, D, T>) -> Result<Self::Output, Error> {
        x.try_clamp(self.min, self.max)
    }
}
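Clamp keeps its bounds as plain dtype fields, so usage is direct (same prelude assumption):

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let layer = Clamp { min: -1.0f32, max: 1.0 };
    let x = dev.tensor([-2.0f32, 0.5, 3.0]);
    // Values below min or above max are pinned to the bounds.
    assert_eq!(layer.forward(x).array(), [-1.0, 0.5, 1.0]);
}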
73 changes: 73 additions & 0 deletions dfdx/src/nn/layers/ops/cmp.rs
@@ -0,0 +1,73 @@
use crate::prelude::*;

/// Calls on [crate::tensor_ops::TryEq], which for tensors is [crate::tensor_ops::eq].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Eq;

/// Calls on [crate::tensor_ops::TryNe], which for tensors is [crate::tensor_ops::ne].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Ne;

/// Calls on [crate::tensor_ops::TryGt], which for tensors is [crate::tensor_ops::gt].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Gt;

/// Calls on [crate::tensor_ops::TryGe], which for tensors is [crate::tensor_ops::ge].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Ge;

/// Calls on [crate::tensor_ops::TryLt], which for tensors is [crate::tensor_ops::lt].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Lt;

/// Calls on [crate::tensor_ops::TryLe], which for tensors is [crate::tensor_ops::le].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Le;

impl<Lhs: TryEq<Rhs>, Rhs> Module<(Lhs, Rhs)> for Eq {
    type Output = <Lhs as TryEq<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_eq(x.1)
    }
}

impl<Lhs: TryNe<Rhs>, Rhs> Module<(Lhs, Rhs)> for Ne {
    type Output = <Lhs as TryNe<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_ne(x.1)
    }
}

impl<Lhs: TryGt<Rhs>, Rhs> Module<(Lhs, Rhs)> for Gt {
    type Output = <Lhs as TryGt<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_gt(x.1)
    }
}

impl<Lhs: TryGe<Rhs>, Rhs> Module<(Lhs, Rhs)> for Ge {
    type Output = <Lhs as TryGe<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_ge(x.1)
    }
}

impl<Lhs: TryLt<Rhs>, Rhs> Module<(Lhs, Rhs)> for Lt {
    type Output = <Lhs as TryLt<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_lt(x.1)
    }
}

impl<Lhs: TryLe<Rhs>, Rhs> Module<(Lhs, Rhs)> for Le {
    type Output = <Lhs as TryLe<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_le(x.1)
    }
}
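The comparison layers produce bool tensors from element-wise comparisons; a sketch for Gt (same prelude assumption), with the other five following identically:

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let a = dev.tensor([1.0f32, 2.0, 3.0]);
    let b = dev.tensor([2.0f32, 2.0, 2.0]);
    // Element-wise a > b, yielding a bool tensor of the same shape.
    assert_eq!(Gt.forward((a, b)).array(), [false, false, true]);
}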
15 changes: 15 additions & 0 deletions dfdx/src/nn/layers/ops/div.rs
@@ -0,0 +1,15 @@
use crate::prelude::*;

/// Calls on [crate::tensor_ops::TryDiv], which for tensors is [crate::tensor_ops::div()].
#[derive(Default, Debug, Clone, Copy, crate::nn::CustomModule)]
pub struct Div;
impl<Lhs, Rhs> Module<(Lhs, Rhs)> for Div
where
    Lhs: TryDiv<Rhs>,
{
    type Output = <Lhs as TryDiv<Rhs>>::Output;

    fn try_forward(&self, x: (Lhs, Rhs)) -> Result<Self::Output, Error> {
        x.0.try_div(x.1)
    }
}
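Div mirrors the Add sketch above, forwarding to TryDiv (same assumptions):

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // Element-wise division of the tuple's parts.
    let q = Div.forward((dev.tensor([4.0f32, 9.0]), dev.tensor([2.0f32, 3.0])));
    assert_eq!(q.array(), [2.0, 3.0]);
}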
11 changes: 10 additions & 1 deletion dfdx/src/nn/layers/ops/dropout.rs
@@ -55,8 +55,10 @@ impl<const N: usize, S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Ten
/// let r = dropout.forward_mut(x.trace(grads));
/// assert_eq!(r.array(), [[2.0, 0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0, 2.0]]);
/// ```
-#[derive(Clone, Debug, CustomModule)]
+#[derive(Clone, Debug, ResetParams, UpdateParams, ZeroGrads)]
+#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct Dropout {
+    #[cfg_attr(feature = "safetensors", serialize)]
    pub p: f64,
}

@@ -67,6 +69,13 @@ impl Default for Dropout {
    }
}

+impl<E: Dtype, D: Device<E>> BuildOnDevice<E, D> for Dropout {
+    type Built = Self;
+    fn try_build_on_device(&self, _device: &D) -> Result<Self::Built, crate::tensor::Error> {
+        Ok(self.clone())
+    }
+}
+
impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Dropout {
    type Output = Tensor<S, E, D, T>;

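A sketch of what the Dropout change enables under the safetensors feature; save_safetensors/load_safetensors are the helpers the new derives are expected to provide, so treat the exact method names and error types here as assumptions:

use dfdx::prelude::*;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dropout = Dropout { p: 0.25 };
    // p is now a serialized field rather than unsaved state.
    dropout.save_safetensors("dropout.safetensors")?;
    let mut restored = Dropout { p: 0.0 };
    restored.load_safetensors("dropout.safetensors")?;
    assert_eq!(restored.p, 0.25);
    Ok(())
}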
33 changes: 33 additions & 0 deletions dfdx/src/nn/layers/ops/huber_error.rs
@@ -0,0 +1,33 @@
use crate::prelude::*;
use std::fmt::Debug;

/// Calls [crate::tensor_ops::huber_error].
#[derive(Clone, Copy, Debug, ResetParams, UpdateParams, ZeroGrads)]
#[cfg_attr(feature = "safetensors", derive(SaveSafeTensors, LoadSafeTensors))]
pub struct HuberError<E> {
    #[cfg_attr(feature = "safetensors", serialize)]
    pub delta: E,
}

impl<E: Dtype, D: Device<E>> BuildOnDevice<E, D> for HuberError<E> {
    type Built = Self;
    fn try_build_on_device(&self, _device: &D) -> Result<Self::Built, crate::tensor::Error> {
        Ok(*self)
    }
}

impl<S: Shape, E: Dtype, D: Device<E>, Lt: Tape<E, D>, Rt: Tape<E, D>>
    Module<(Tensor<S, E, D, Lt>, Tensor<S, E, D, Rt>)> for HuberError<E>
where
    Lt: Merge<Rt>,
    E: Into<f64>,
{
    type Output = Tensor<S, E, D, Lt>;

    fn try_forward(
        &self,
        x: (Tensor<S, E, D, Lt>, Tensor<S, E, D, Rt>),
    ) -> Result<Self::Output, Error> {
        x.0.try_huber_error(x.1, self.delta)
    }
}
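HuberError is a two-input loss layer around the element-wise Huber op; a sketch (same prelude assumption), with the reduction again left to the caller:

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let layer = HuberError { delta: 1.0f32 };
    let pred = dev.tensor([0.0f32, 2.0]);
    let target = dev.tensor([0.25f32, 0.0]);
    // |diff| <= delta gives 0.5 * diff^2; |diff| > delta gives delta * (|diff| - 0.5 * delta).
    assert_eq!(layer.forward((pred, target)).array(), [0.03125, 1.5]);
}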