diff --git a/src/nn/mod.rs b/src/nn/mod.rs
index ddf87b9ed..c9c22a1e4 100644
--- a/src/nn/mod.rs
+++ b/src/nn/mod.rs
@@ -190,6 +190,8 @@ mod batchnorm2d;
 mod bias2d;
 #[cfg(feature = "nightly")]
 mod conv;
+#[cfg(feature = "nightly")]
+mod rec;
 mod convtrans;
 mod dropout;
 mod ema;
@@ -284,6 +286,10 @@ pub mod builders {
     pub use super::dropout::{Dropout, DropoutOneIn};
     pub use super::embedding::builder::Embedding;
     #[cfg(feature = "nightly")]
+    pub use super::rec::builder::RNN;
+    #[cfg(feature = "nightly")]
+    pub use super::rec::builder::GRU;
+    #[cfg(feature = "nightly")]
     pub use super::flatten::Flatten2D;
     pub use super::generalized_residual::GeneralizedResidual;
     pub use super::layer_norm::builder::LayerNorm1D;
diff --git a/src/nn/rec.rs b/src/nn/rec.rs
new file mode 100644
index 000000000..c5ce53869
--- /dev/null
+++ b/src/nn/rec.rs
@@ -0,0 +1,705 @@
+use crate::{
+    nn::linear::Linear,
+    prelude::{
+        BuildModule, BuildOnDevice, Const, Device, Dim, Dtype, HasShape, Module, ModuleVisitor,
+        NonMutableModule, PutTape, SelectTo, Sigmoid, SplitTape, Tanh, Tape, Tensor,
+        TensorCollection, TensorFrom, TryStack,
+    },
+};
+use core::marker::PhantomData;
+use num_traits::Float;
+use rand_distr::uniform::SampleUniform;
+
+pub mod builder {
+    use crate::prelude::Dim;
+    use core::marker::PhantomData;
+
+    #[derive(Debug)]
+    pub struct RNN<
+        const IN_CHAN: usize,
+        const OUT_CHAN: usize,
+        InSeq: Dim = usize,
+        OutSeq: Dim = usize,
+    > {
+        in_seq: PhantomData<InSeq>,
+        out_seq: PhantomData<OutSeq>,
+    }
+
+    #[derive(Debug)]
+    pub struct GRU<
+        const IN_CHAN: usize,
+        const OUT_CHAN: usize,
+        InSeq: Dim = usize,
+        OutSeq: Dim = usize,
+    > {
+        in_seq: PhantomData<InSeq>,
+        out_seq: PhantomData<OutSeq>,
+    }
+}
+
+impl<const I: usize, const O: usize, E, D, IS, OS> BuildOnDevice<D, E>
+    for builder::RNN<I, O, IS, OS>
+where
+    E: Dtype,
+    D: Device<E>,
+    IS: Dim,
+    OS: Dim,
+    RNN<I, O, E, D, IS, OS>: BuildModule<D, E>,
+{
+    type Built = RNN<I, O, E, D, IS, OS>;
+    fn try_build_on_device(device: &D) -> Result<Self::Built, D::Err> {
+        Self::Built::try_build(device)
+    }
+}
+
+impl<const I: usize, const O: usize, E, D, IS, OS> BuildOnDevice<D, E>
+    for builder::GRU<I, O, IS, OS>
+where
+    E: Dtype,
+    D: Device<E>,
+    IS: Dim,
+    OS: Dim,
+    GRU<I, O, E, D, IS, OS>: BuildModule<D, E>,
+{
+    type Built = GRU<I, O, E, D, IS, OS>;
+    fn try_build_on_device(device: &D) -> Result<Self::Built, D::Err> {
+        Self::Built::try_build(device)
+    }
+}
+
+// TODO: support batch on cell forward
+pub trait RecCell<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>> {
+    fn cell_try_forward(
+        &self,
+        x: Tensor<(Const<I>,), E, D, T>,
+        h: Tensor<(Const<O>,), E, D, T>,
+    ) -> Result<Tensor<(Const<O>,), E, D, T>, D::Err>;
+}
+
+pub struct RNN<const I: usize, const O: usize, E: Dtype, D: Device<E>, IS: Dim, OS: Dim> {
+    l_x: Linear<I, O, E, D>,
+    l_h: Linear<O, O, E, D>,
+    tanh: Tanh,
+    is: PhantomData<IS>,
+    os: PhantomData<OS>,
+}
+
+impl<const I: usize, const O: usize, E, D, IS, OS> TensorCollection<E, D>
+    for RNN<I, O, E, D, IS, OS>
+where
+    E: Dtype + Float + SampleUniform,
+    D: Device<E>,
+    IS: Dim,
+    OS: Dim,
+{
+    type To<E2: Dtype, D2: Device<E2>> = RNN<I, O, E2, D2, IS, OS>;
+
+    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
+        visitor: &mut V,
+    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
+        visitor.visit_fields(
+            (
+                Self::module("l_x", |s| &s.l_x, |s| &mut s.l_x),
+                Self::module("l_h", |s| &s.l_h, |s| &mut s.l_h),
+            ),
+            |(l_x, l_h)| RNN {
+                l_x,
+                l_h,
+                tanh: Default::default(),
+                is: Default::default(),
+                os: Default::default(),
+            },
+        )
+    }
+}
+
+impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, IS: Dim, OS: Dim> NonMutableModule
+    for RNN<I, O, E, D, IS, OS>
+{
+}
+
+impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>, IS: Dim, OS: Dim>
+    RecCell<I, O, E, D, T> for RNN<I, O, E, D, IS, OS>
+{
+    fn cell_try_forward(
+        &self,
+        x: Tensor<(Const<I>,), E, D, T>,
+        h: Tensor<(Const<O>,), E, D, T>,
+    ) -> Result<Tensor<(Const<O>,), E, D, T>, D::Err> {
+        self.tanh
+            .try_forward(self.l_x.try_forward(x)? + self.l_h.try_forward(h)?)
+    }
+}
+
+pub struct GRU<const I: usize, const O: usize, E: Dtype, D: Device<E>, IS: Dim, OS: Dim> {
+    l_xr: Linear<I, O, E, D>,
+    l_hr: Linear<O, O, E, D>,
+    l_xz: Linear<I, O, E, D>,
+    l_hz: Linear<O, O, E, D>,
+    l_xn: Linear<I, O, E, D>,
+    l_hn: Linear<O, O, E, D>,
+    sigmoid: Sigmoid,
+    tanh: Tanh,
+    is: PhantomData<IS>,
+    os: PhantomData<OS>,
+}
+
+impl<const I: usize, const O: usize, E, D, IS, OS> TensorCollection<E, D>
+    for GRU<I, O, E, D, IS, OS>
+where
+    E: Dtype + Float + SampleUniform,
+    D: Device<E>,
+    IS: Dim,
+    OS: Dim,
+{
+    type To<E2: Dtype, D2: Device<E2>> = GRU<I, O, E2, D2, IS, OS>;
+
+    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
+        visitor: &mut V,
+    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
+        visitor.visit_fields(
+            (
+                Self::module("l_xr", |s| &s.l_xr, |s| &mut s.l_xr),
+                Self::module("l_hr", |s| &s.l_hr, |s| &mut s.l_hr),
+                Self::module("l_xz", |s| &s.l_xz, |s| &mut s.l_xz),
+                Self::module("l_hz", |s| &s.l_hz, |s| &mut s.l_hz),
+                Self::module("l_xn", |s| &s.l_xn, |s| &mut s.l_xn),
+                Self::module("l_hn", |s| &s.l_hn, |s| &mut s.l_hn),
+            ),
+            |(l_xr, l_hr, l_xz, l_hz, l_xn, l_hn)| GRU {
+                l_xr,
+                l_hr,
+                l_xz,
+                l_hz,
+                l_xn,
+                l_hn,
+                sigmoid: Default::default(),
+                tanh: Default::default(),
+                is: Default::default(),
+                os: Default::default(),
+            },
+        )
+    }
+}
+impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, IS: Dim, OS: Dim> NonMutableModule
+    for GRU<I, O, E, D, IS, OS>
+{
+}
+impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>, IS: Dim, OS: Dim>
+    RecCell<I, O, E, D, T> for GRU<I, O, E, D, IS, OS>
+{
+    fn cell_try_forward(
+        &self,
+        x: Tensor<(Const<I>,), E, D, T>,
+        h: Tensor<(Const<O>,), E, D, T>,
+    ) -> Result<Tensor<(Const<O>,), E, D, T>, D::Err> {
+        let r = self.sigmoid.try_forward(
+            self.l_xr.try_forward(x.retaped::<T>())? + self.l_hr.try_forward(h.retaped::<T>())?,
+        )?;
+        let z = self.sigmoid.try_forward(
+            self.l_xz.try_forward(x.retaped::<T>())? + self.l_hz.try_forward(h.retaped::<T>())?,
+        )?;
+        let n = self.tanh.try_forward(
+            self.l_xn.try_forward(x)? + r * self.l_hn.try_forward(h.retaped::<T>())?,
+        )?;
+        let ones = D::default().ones();
+        Ok((-z.retaped() + ones) * n + z * h)
+    }
+}
+
+macro_rules! cell_impls {
+    ($cell:ident) => {
+        // usize -> usize
+
+        impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
+            Module<Tensor<(usize, Const<I>), E, D, T>> for $cell<I, O, E, D, usize, usize>
+        {
+            type Output = Tensor<(usize, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: usize, output_seq: usize, batch: no
+            fn try_forward(
+                &self,
+                input: Tensor<(usize, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                let mut hs = vec![];
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for s in 0..input.shape().0 {
+                    let x = input.retaped::<T>().try_select(dev.tensor(s))?;
+                    let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                    tape = tape.merge(tape_select_cell);
+                    hs.push(h_new.retaped::<T>());
+                    h = h_new.retaped::<T>();
+                }
+                let (hs, tape_stack) = hs.try_stack()?.split_tape();
+                Ok(hs.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
+            Module<Tensor<(usize, usize, Const<I>), E, D, T>> for $cell<I, O, E, D, usize, usize>
+        {
+            type Output = Tensor<(usize, usize, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: usize, output_seq: usize, batch: usize
+            fn try_forward(
+                &self,
+                input: Tensor<(usize, usize, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // (batch, seq, O)
+                let mut hsb = vec![];
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for b in 0..input.shape().0 {
+                    // (seq, O)
+                    let mut hs = vec![];
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..seq.shape().0 {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                        tape = tape.merge(tape_select_cell);
+                        hs.push(h_new.retaped::<T>());
+                        h = h_new.retaped::<T>();
+                    }
+                    hsb.push(hs.try_stack()?);
+                }
+                let (hsb, tape_stack) = hsb.try_stack()?.split_tape();
+                Ok(hsb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const B: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(Const<B>, usize, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, usize, usize>
+        {
+            type Output = Tensor<(Const<B>, usize, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: usize, output_seq: usize, batch: B
+            fn try_forward(
+                &self,
+                input: Tensor<(Const<B>, usize, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // HACK: better way of creating const size tensors
+                let init = dev.zeros_like(&(input.shape().1, Const::<O>));
+                // (batch, seq, O)
+                let mut hsb = [(); B].map(|_| init.retaped());
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for b in 0..B {
+                    // (seq, O)
+                    let mut hs = vec![];
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..seq.shape().0 {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                        tape = tape.merge(tape_select_cell);
+                        hs.push(h_new.retaped::<T>());
+                        h = h_new.retaped::<T>();
+                    }
+                    hsb[b] = hs.try_stack()?;
+                }
+                let (hsb, tape_stack) = hsb.try_stack()?.split_tape();
+                Ok(hsb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        // usize -> 1
+
+        impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
+            Module<Tensor<(usize, Const<I>), E, D, T>> for $cell<I, O, E, D, usize, Const<1>>
+        {
+            type Output = Tensor<(Const<O>,), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: usize, output_seq: 1, batch: no
+            fn try_forward(
+                &self,
+                input: Tensor<(usize, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, tape) = input.split_tape();
+                for s in 0..input.shape().0 {
+                    let x = input.retaped::<T>().try_select(dev.tensor(s))?;
+                    h = self.cell_try_forward(x, h)?;
+                }
+                let (h, tape_cell) = h.split_tape();
+                Ok(h.put_tape(tape.merge(tape_cell)))
+            }
+        }
+
+        impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
+            Module<Tensor<(usize, usize, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, usize, Const<1>>
+        {
+            type Output = Tensor<(usize, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: usize, output_seq: 1, batch: usize
+            fn try_forward(
+                &self,
+                input: Tensor<(usize, usize, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // (batch, O)
+                let mut hb = vec![];
+                let (input, tape) = input.split_tape();
+                for b in 0..input.shape().0 {
+                    // (seq, O)
+                    let mut h = dev.zeros().retaped::<T>();
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..seq.shape().0 {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        h = self.cell_try_forward(x, h)?;
+                    }
+                    hb.push(h);
+                }
+                let (hb, tape_stack) = hb.try_stack()?.split_tape();
+                Ok(hb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const B: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(Const<B>, usize, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, usize, Const<1>>
+        {
+            type Output = Tensor<(Const<B>, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: usize, output_seq: 1, batch: B
+            fn try_forward(
+                &self,
+                input: Tensor<(Const<B>, usize, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // HACK: better way of creating const size tensors
+                let init = dev.zeros_like(&(Const::<O>,));
+                // (batch, O)
+                let mut hb = [(); B].map(|_| init.retaped::<T>());
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for b in 0..B {
+                    // (seq, O)
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..seq.shape().0 {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        let (h_new, tape_select_forward) =
+                            self.cell_try_forward(x, h)?.split_tape();
+                        tape = tape.merge(tape_select_forward);
+                        h = h_new.retaped();
+                    }
+                    hb[b] = h.retaped::<T>();
+                }
+                let (hb, tape_stack) = hb.try_stack()?.split_tape();
+                Ok(hb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        // S -> S
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const S: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(Const<S>, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, Const<S>, Const<S>>
+        where
+            Assert<{ S > 1 }>: IsTrue,
+        {
+            type Output = Tensor<(Const<S>, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: S, output_seq: S, batch: no
+            fn try_forward(
+                &self,
+                input: Tensor<(Const<S>, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                let init = dev.zeros_like(&(Const::<O>,));
+                let mut hs = [(); S].map(|_| init.retaped::<T>());
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for s in 0..S {
+                    let x = input.retaped::<T>().try_select(dev.tensor(s))?;
+                    let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                    tape = tape.merge(tape_select_cell);
+                    hs[s] = h_new.retaped::<T>();
+                    h = h_new.retaped::<T>();
+                }
+                let (hs, tape_stack) = hs.try_stack()?.split_tape();
+                Ok(hs.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const S: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(usize, Const<S>, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, Const<S>, Const<S>>
+        where
+            Assert<{ S > 1 }>: IsTrue,
+        {
+            type Output = Tensor<(usize, Const<S>, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: S, output_seq: S, batch: usize
+            fn try_forward(
+                &self,
+                input: Tensor<(usize, Const<S>, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // (batch, seq, O)
+                let mut hsb = vec![];
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for b in 0..input.shape().0 {
+                    let init = dev.zeros_like(&(Const::<O>,));
+                    // (seq, O)
+                    let mut hs = [(); S].map(|_| init.retaped::<T>());
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..S {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                        tape = tape.merge(tape_select_cell);
+                        hs[s] = h_new.retaped::<T>();
+                        h = h_new.retaped::<T>();
+                    }
+                    hsb.push(hs.try_stack()?);
+                }
+                let (hsb, tape_stack) = hsb.try_stack()?.split_tape();
+                Ok(hsb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const S: usize,
+                const B: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(Const<B>, Const<S>, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, Const<S>, Const<S>>
+        where
+            Assert<{ S > 1 }>: IsTrue,
+        {
+            type Output = Tensor<(Const<B>, Const<S>, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: S, output_seq: S, batch: B
+            fn try_forward(
+                &self,
+                input: Tensor<(Const<B>, Const<S>, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // HACK: better way of creating const size tensors
+                let init = dev.zeros_like(&(Const::<S>, Const::<O>));
+                // (batch, seq, O)
+                let mut hsb = [(); B].map(|_| init.retaped());
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for b in 0..B {
+                    let init = dev.zeros_like(&(Const::<O>,));
+                    // (seq, O)
+                    let mut hs = [(); S].map(|_| init.retaped::<T>());
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..S {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                        tape = tape.merge(tape_select_cell);
+                        hs[s] = h_new.retaped::<T>();
+                        h = h_new.retaped::<T>();
+                    }
+                    hsb[b] = hs.try_stack()?;
+                }
+                let (hsb, tape_stack) = hsb.try_stack()?.split_tape();
+                Ok(hsb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        // S -> 1
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const S: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(Const<S>, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, Const<S>, Const<1>>
+        {
+            type Output = Tensor<(Const<O>,), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: S, output_seq: 1, batch: no
+            fn try_forward(
+                &self,
+                input: Tensor<(Const<S>, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, tape) = input.split_tape();
+                for s in 0..S {
+                    let x = input.retaped::<T>().try_select(dev.tensor(s))?;
+                    h = self.cell_try_forward(x, h)?;
+                }
+                let (h, tape_select_cell) = h.split_tape();
+                Ok(h.put_tape(tape.merge(tape_select_cell)))
+            }
+        }
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const S: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(usize, Const<S>, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, Const<S>, Const<1>>
+        {
+            type Output = Tensor<(usize, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: S, output_seq: 1, batch: usize
+            fn try_forward(
+                &self,
+                input: Tensor<(usize, Const<S>, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // (batch, O)
+                let mut hb = vec![];
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for b in 0..input.shape().0 {
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..S {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                        tape = tape.merge(tape_select_cell);
+                        h = h_new.retaped::<T>();
+                    }
+                    hb.push(h.retaped::<T>());
+                }
+                let (hb, tape_stack) = hb.try_stack()?.split_tape();
+                Ok(hb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+
+        impl<
+                const I: usize,
+                const O: usize,
+                const S: usize,
+                const B: usize,
+                E: Dtype,
+                D: Device<E>,
+                T: Tape<E, D>,
+            > Module<Tensor<(Const<B>, Const<S>, Const<I>), E, D, T>>
+            for $cell<I, O, E, D, Const<S>, Const<1>>
+        {
+            type Output = Tensor<(Const<B>, Const<O>), E, D, T>;
+            type Error = D::Err;
+
+            /// input_seq: S, output_seq: 1, batch: B
+            fn try_forward(
+                &self,
+                input: Tensor<(Const<B>, Const<S>, Const<I>), E, D, T>,
+            ) -> Result<Self::Output, Self::Error> {
+                let dev = D::default();
+                // HACK: better way of creating const size tensors
+                let init = dev.zeros_like(&(Const::<O>,));
+                // (batch, O)
+                let mut hb = [(); B].map(|_| init.retaped::<T>());
+                let mut h = dev.zeros().retaped::<T>();
+                let (input, mut tape) = input.split_tape();
+                for b in 0..B {
+                    let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
+                    for s in 0..S {
+                        let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
+                        let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
+                        tape = tape.merge(tape_select_cell);
+                        h = h_new.retaped::<T>();
+                    }
+                    hb[b] = h.retaped::<T>();
+                }
+                let (hsb, tape_stack) = hb.try_stack()?.split_tape();
+                Ok(hsb.put_tape(tape.merge(tape_stack)))
+            }
+        }
+    };
+}
+
+cell_impls!(RNN);
+cell_impls!(GRU);
+
+pub enum Assert<const CHECK: bool> {}
+pub trait IsTrue {}
+impl IsTrue for Assert<true> {}
+
+#[cfg(test)]
+mod tests {
+    use super::{builder::RNN, *};
+    use crate::{
+        prelude::{Const, DeviceBuildExt, Tensor, ZerosTensor},
+        tests::{TestDevice, TestDtype},
+    };
+
+    #[test]
+    fn test_forward() {
+        let dev: TestDevice = Default::default();
+        let x = dev.zeros_like(&(Const::<10>, Const::<3>));
+        let _: Tensor<(Const<10>, Const<1>), _, _, _> = dev
+            .build_module::<RNN<3, 1, Const<10>, Const<10>>, TestDtype>()
+            .forward(x.clone());
+        let _: Tensor<(Const<10>, Const<5>), _, _, _> = dev
+            .build_module::<RNN<3, 5, Const<10>, Const<10>>, TestDtype>()
+            .forward(x.clone());
+        let _: Tensor<(Const<5>,), _, _, _> = dev
+            .build_module::<RNN<3, 5, Const<10>, Const<1>>, TestDtype>()
+            .forward(x.clone());
+    }
+
+    #[test]
+    fn test_batch_forward() {
+        let dev: TestDevice = Default::default();
+        let x = dev.zeros_like(&(Const::<32>, Const::<10>, Const::<3>));
+        let _: Tensor<(Const<32>, Const<10>, Const<1>), _, _, _> = dev
+            .build_module::<RNN<3, 1, Const<10>, Const<10>>, TestDtype>()
+            .forward(x.clone());
+        let _: Tensor<(Const<32>, Const<10>, Const<5>), _, _, _> = dev
+            .build_module::<RNN<3, 5, Const<10>, Const<10>>, TestDtype>()
+            .forward(x.clone());
+        let _: Tensor<(Const<32>, Const<5>), _, _, _> = dev
+            .build_module::<RNN<3, 5, Const<10>, Const<1>>, TestDtype>()
+            .forward(x.clone());
+    }
+
+    // TODO: more tests
+}
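Usage note (not part of the patch): a minimal sketch of how downstream code might drive the new builders once this lands. It assumes dfdx's existing `Cpu` device, `DeviceBuildExt::build_module`, `ZerosTensor::zeros`, and `Module::forward` APIs, a nightly toolchain with the crate's "nightly" feature enabled, and arbitrary placeholder channel counts and sequence lengths.

    // Sketch only: builder parameters are GRU<IN_CHAN, OUT_CHAN, InSeq, OutSeq>.
    // May also need #![feature(generic_const_exprs)] in the downstream crate.
    use dfdx::{nn::builders::GRU, prelude::*};

    fn main() {
        let dev: Cpu = Default::default();

        // A fixed 10-step input sequence reduced to a single output step.
        let gru = dev.build_module::<GRU<3, 5, Const<10>, Const<1>>, f32>();

        // One unbatched sequence: 10 steps of 3 channels each (placeholder data).
        let x: Tensor<(Const<10>, Const<3>), f32, Cpu> = dev.zeros();

        // With OutSeq = Const<1> the output is the final hidden state.
        let _h: Tensor<(Const<5>,), f32, Cpu> = gru.forward(x);
    }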