
Adding basic RNN implementation #798

Closed · wants to merge 2 commits
Conversation


xubaiw commented Jun 25, 2023

Resolves #204

Usable, but far from an ergonomic design. Many workarounds and hacks are employed.

Ideal design

Low level: cells

At the low level, we can construct several cells, such as struct RNNCell<I, O>, struct LSTMCell<I, O>, and struct GRUCell<I, O>.

All of these cells implement tupled forward passes:

  1. single forward:
impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
Module<(Tensor<Rank1<I>, E, D, T>, Tensor<Rank1<O>, E, D, T>)>
for RNNCell<I, O> {
    type Output = Tensor<Rank1<O>, E, D, T>;
}
  2. batch forward:
impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
Module<(Tensor<(usize, Const<I>), E, D, T>, Tensor<(usize, Const<O>), E, D, T>)>
for RNNCell<I, O> {
    type Output = Tensor<(usize, Const<O>), E, D, T>;
}
  3. exact batch forward:
impl<const I: usize, const O: usize, const B: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
Module<(Tensor<Rank2<B, I>, E, D, T>, Tensor<Rank2<B, O>, E, D, T>)>
for RNNCell<I, O> {
    type Output = Tensor<Rank2<B, O>, E, D, T>;
}

Not solved: how does LSTM fit into this model (how do we batch both $h_{t-1}$ and $c_{t-1}$)?
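One possible direction, sketched here purely for illustration and not settled in this PR, is to carry the pair $(h_{t-1}, c_{t-1})$ as a tuple, so the single forward for an LSTM cell would look like:

impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
Module<(Tensor<Rank1<I>, E, D, T>, (Tensor<Rank1<O>, E, D, T>, Tensor<Rank1<O>, E, D, T>))>
for LSTMCell<I, O> {
    // hypothetical: the cell returns the updated (h, c) pair as its output
    type Output = (Tensor<Rank1<O>, E, D, T>, Tensor<Rank1<O>, E, D, T>);
}

The batched variants would then need to batch each tensor of the pair independently, which is exactly the open question above.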

The cell-level components allow users to keep hidden states themselves and operate on them freely:

const I: usize = 3;
const O: usize = 5;

let rnnc = dev.build_module::<RNNCell<I, O>>();
let mut h: Tensor<Rank1<O>> = dev.zeros();

for x in xs {
  h = rnnc.forward((x, h));
}

High level: recursor

Upon the low-level cells we can build a high-level recursor:

mod builder {
    struct Rec<Cell, InSeq: Dim = usize, OutSeq: Dim = usize>;
}

struct Rec<Cell, E: Dtype, D: Device<E>, InSeq: Dim = usize, OutSeq: Dim = usize>;

where Cell can be any struct that implements the single and batch forwards above. User-defined structs are also allowed.

With this wrapper we can easily integrate with the tuple (sequential) model:

// e.g. take a sequence of stock prices `(usize, Const<3>)` and predict bull/bear `Rank1<2>`
type Mlp = (
    Rec<GRUCell<3, 32>, usize, Const<1>>,
    Linear<32, 2>,
    Softmax
);

The InSeq and OutSeq dims are used to control sequence length. The following sequence-length combinations are implemented:

  • Rec<C, usize, usize>: sequence to sequence ((usize, Const<I>) to (usize, Const<O>))
  • Rec<C, usize, Const<1>>: sequence to last ((usize, Const<I>) to Rank1<O>)
  • Rec<C, Const<S>, Const<S>>: exact sequence to exact sequence (Rank2<S, I> to Rank2<S, O>)
  • Rec<C, Const<S>, Const<1>>: exact sequence to last (Rank2<S, I> to Rank1<O>)
  • Rec<C, Const<1>, Const<S>>: first to exact sequence (Rank1<I> to Rank2<S, O>)

There is also a batched version of each of the impls above.
E.g. for Rec<C, Const<S>, Const<1>>:

  • batched version: (usize, Const<S>, Const<I>) -> (usize, Const<O>)
  • exact batched version: Rank3<B, S, I> -> Rank2<B, O>.
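To make the modes above concrete, here is a hypothetical usage of the "sequence to last" variant, in the same loose style as the cell-level example earlier (none of this API is implemented yet; shapes follow the bullets above):

// sequence to last: (usize, Const<3>) -> Rank1<8>
let rnn = dev.build_module::<Rec<RNNCell<3, 8>, usize, Const<1>>>();

// a runtime-length sequence of 10 three-dimensional inputs -> the last hidden state
let xs: Tensor<(usize, Const<3>)> = dev.zeros_like(&(10, Const::<3>));
let last: Tensor<Rank1<8>> = rnn.forward(xs);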

Current workaround

Currently, Rust requires const generic parameters of an impl to appear either in the trait being implemented or in the implementing type. So

impl<const I: usize, const O: usize, C, E, D, T, IS, OS>
Module<Tensor<(usize, Const<I>), E, D, T>> for Rec<C, IS, OS>
where
    C: Module<(Tensor<Rank1<I>, E, D, T>, Tensor<Rank1<O>, E, D, T>), Output = Tensor<Rank1<O>, E, D, T>>,
    E: Dtype,
    D: Device<E>,
    T: Tape<E, D>,
    IS: Dim,
    OS: Dim
{
    ...
}

is not allowed and will give an unconstrained const generic error on O.
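A minimal, self-contained illustration of the same limitation, with plain arrays standing in for tensors (Forward is a simplified stand-in for Module, not dfdx's trait):

trait Forward<Input> {
    type Output;
    fn forward(&self, x: Input) -> Self::Output;
}

struct Rec<C>(C);

// error[E0207]: the const parameter `O` is not constrained by the impl trait, self type, or predicates
impl<const I: usize, const O: usize, C> Forward<[f32; I]> for Rec<C>
where
    C: Forward<([f32; I], [f32; O]), Output = [f32; O]>,
{
    type Output = [f32; O];
    fn forward(&self, x: [f32; I]) -> [f32; O] {
        // step the wrapped cell once with a zero hidden state (illustrative only)
        self.0.forward((x, [0.0; O]))
    }
}

O only appears in the where clause (and there only together with the projection it is supposed to determine), so the compiler rejects the impl, just like the dfdx version above.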

So the current workaround does not use a Rec struct; instead, the Module implementations are written directly on each cell.

Solution

@coreylowman

  1. Can we refactor trait Module<Input> to trait Module<Input, Output> to overcome Rust's limitation on const generics? (See the sketch below.)
  2. By the way, OwnedTape is not cloneable right now (although Clone for GradientTape & OwnedTape #251 says it is), while manual Tape management adds a lot of verbosity and is very error prone.
    Can we refactor BackwardOp from Box to Arc to allow cloning of OwnedTape?
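For question 1, here is a sketch of how moving Output into the trait's generic parameters would sidestep the constraint problem, continuing the simplified stand-in types from the illustration above (Forward2 is not dfdx's actual trait):

trait Forward2<Input, Output> {
    fn forward(&self, x: Input) -> Output;
}

// O now appears in the impl's trait reference, so it is constrained and this compiles
impl<const I: usize, const O: usize, C> Forward2<[f32; I], [f32; O]> for Rec<C>
where
    C: Forward2<([f32; I], [f32; O]), [f32; O]>,
{
    fn forward(&self, x: [f32; I]) -> [f32; O] {
        // recurrent step with a zero-initialized hidden state, for illustration only
        self.0.forward((x, [0.0; O]))
    }
}

The trade-off is that Output is no longer uniquely determined by Input, so type inference on forward calls can get weaker.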

nkoppel (Contributor) left a comment

Thank you for the contribution! Aside from my review comments, I have some concerns about how this is currently implemented.

  • This implementation does not support continuing evaluation where it left off, which makes it difficult to make interactive applications, train on long sequences, etc.
  • Passing sequences in as Tensors, splitting them into individual tokens, and merging the result back into a tensor feels unnatural to me; sequences should be passed in as vectors instead.

@coreylowman @xubaiw I am curious to hear your thoughts on the following implementation of RNNs:

  • Have an object that represents an RNN's state, like
struct RnnState {
    tensors: BTreeMap<UniqueId, Box<dyn Any>>
}
  • Have RNN modules implement Module<(RnnState, Input), Output = (RnnState, Output)>
    • We can also add a module that wraps normal Input -> Output modules and passes the RnnState forward unmodified.
  • RNN modules access and modify their state by storing UniqueIds corresponding to each state tensor, and using Any::downcast. If an object of the wrong type exists, the program panics, and if no object exists, it creates a new tensor containing zeros.
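A rough, self-contained sketch of that state-lookup behaviour, with a plain u64 standing in for UniqueId and get_or_init/set as hypothetical helper names:

use std::any::Any;
use std::collections::BTreeMap;

type UniqueId = u64; // stand-in for dfdx's UniqueId

struct RnnState {
    tensors: BTreeMap<UniqueId, Box<dyn Any>>,
}

impl RnnState {
    /// Fetch the state stored under `id`, creating it with `init` (e.g. a zero tensor)
    /// if absent; panics if an object of the wrong type is stored under `id`.
    fn get_or_init<S: Clone + 'static>(&mut self, id: UniqueId, init: impl FnOnce() -> S) -> S {
        self.tensors
            .entry(id)
            .or_insert_with(|| Box::new(init()) as Box<dyn Any>)
            .downcast_ref::<S>()
            .expect("RnnState entry has unexpected type")
            .clone()
    }

    /// Store the updated state back after a forward step.
    fn set<S: 'static>(&mut self, id: UniqueId, value: S) {
        self.tensors.insert(id, Box::new(value));
    }
}

An RNN cell's forward would then call something like state.get_or_init(self.id, || zeros()) at the start of a step and state.set(self.id, new_h) at the end.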

Comment on lines +129 to +130
self.tanh
.try_forward(self.l_x.try_forward(x)? + self.l_h.try_forward(h)?)

Suggested change
self.tanh
.try_forward(self.l_x.try_forward(x)? + self.l_h.try_forward(h)?)
(self.l_x.try_forward(x)? + self.l_h.try_forward(h)?).try_tanh()

You can remove self.tanh and other stored activation functions.

Comment on lines +79 to +85
pub struct RNN<const I: usize, const O: usize, E: Dtype, D: Device<E>, IS: Dim, OS: Dim> {
    l_x: Linear<I, O, E, D>,
    l_h: Linear<O, O, E, D>,
    tanh: Tanh,
    is: PhantomData<IS>,
    os: PhantomData<OS>,
}
nkoppel (Contributor) commented Jul 11, 2023

The IS and OS type variables should be removed from this, because they arbitrarily limit how these modules can be used. Instead, I think we should assume that OS is always the same as IS, and implement a separate "SelectLast" operation/module that gets the last element from these sequences.
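A signature-level sketch of that idea, in the same style as the impls proposed in this PR (hypothetical, not part of the change):

struct SelectLast;

impl<const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
Module<Tensor<(usize, Const<O>), E, D, T>> for SelectLast {
    // returns the last element of the sequence, e.g. via the same
    // try_select(dev.tensor(seq_len - 1)) pattern used elsewhere in this PR
    type Output = Tensor<Rank1<O>, E, D, T>;
}

The RNN modules would then always map sequences to sequences, and a trailing SelectLast in the tuple model would pick out the final hidden state.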

Comment on lines +247 to +273
    /// input_seq: usize, output_seq: usize, batch: usize
    fn try_forward(
        &self,
        input: Tensor<(usize, usize, Const<I>), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        let dev = D::default();
        // (batch, seq, O)
        let mut hsb = vec![];
        let mut h = dev.zeros().retaped::<T>();
        let (input, mut tape) = input.split_tape();
        for b in 0..input.shape().0 {
            // (seq, O)
            let mut hs = vec![];
            let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
            for s in 0..seq.shape().0 {
                let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
                let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
                tape = tape.merge(tape_select_cell);
                hs.push(h_new.retaped::<T>());
                h = h_new.retaped::<T>();
            }
            hsb.push(hs.try_stack()?);
        }
        let (hsb, tape_stack) = hsb.try_stack()?.split_tape();
        Ok(hsb.put_tape(tape.merge(tape_stack)))
    }
}

This and the impl below can be deduplicated by having a type parameter B: Dim and using input.shape().concrete()[0].

xubaiw closed this Aug 12, 2023
JRazek commented Nov 30, 2023

@nkoppel, I'd like to further investigate #204. Would you mind explaining what you meant by this comment?
I'm not sure when exactly such a computation could be left off. Are there any examples in the existing code?

This implementation does not support continuing evaluation where it left off, which makes it difficult to make interactive applications, train on long sequences, etc.

nkoppel (Contributor) commented Nov 30, 2023

It has been a few months since I worked on dfdx, but I'll do my best. The issue I was explaining with this comment was that this implementation does not store the hidden state for RNNs in any way, instead passing in the entire sequence at once like a transformer.

This is very inefficient for some use cases of RNNs. For example, to run a generative language model, you would have to recompute every previous hidden state before you could generate the next character/token, instead of simply using the hidden state you've computed already for the last character/token.
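In the cell-level style from the PR description, the difference looks roughly like this (embed, head, and sample_token are hypothetical helpers):

// with a cached hidden state, each new token costs one cell step
let mut h: Tensor<Rank1<O>> = dev.zeros();
for tok in prompt {
    h = rnnc.forward((embed(tok), h)); // consume the prompt once
}
for _ in 0..num_new_tokens {
    let tok = sample_token(head.forward(h.clone()));
    h = rnnc.forward((embed(tok), h)); // no recomputation of the prefix
}

Without access to h between calls, each of those generation steps would instead have to re-run the cell over the entire prefix.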

My proposal was an effort to resolve this by passing in the state of the entire RNN into its input and having each RNN module retrieve and update its hidden state during inference. I reasoned that this would allow a lot of flexibility in how the network is used and would avoid mutating the network during inference.

Successfully merging this pull request may close these issues: Add RNN and LSTM modules in nn

4 participants