Adding basic RNN implementation #798
Conversation
Thank you for the contribution! Aside from my review comments, I have some concerns about how this is currently implemented.
- This implementation does not support continuing evaluation where it left off, which makes it difficult to build interactive applications, train on long sequences, etc.
- Passing sequences in as Tensors, splitting them into individual tokens, and merging the result back into a tensor feels unnatural to me; sequences should be passed in as vectors instead.
@coreylowman @xubaiw I am curious to hear your thoughts on the following implementation of RNNs:
- Have an object that represents an RNN's state, like
struct RnnState {
    tensors: BTreeMap<UniqueId, Box<dyn Any>>,
}
- Have RNN modules implement `Module<(RnnState, Input), Output = (RnnState, Output)>`.
- We can also add a module that wraps normal `Input -> Output` modules and passes the `RnnState` forward unmodified.
- RNN modules access and modify their state by storing `UniqueId`s corresponding to each state tensor, and using `Any::downcast`. If an object of the wrong type exists, the program panics; if no object exists, a new tensor containing zeros is created.
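A minimal sketch of how such a state map could work. Here `UniqueId` is stood in for by a `u64`, and `get_or_init` is a hypothetical helper, not dfdx API:

```rust
use std::any::Any;
use std::collections::BTreeMap;

// Stand-in for dfdx's `UniqueId`: assumed to be an ordered, copyable key.
type UniqueId = u64;

struct RnnState {
    tensors: BTreeMap<UniqueId, Box<dyn Any>>,
}

impl RnnState {
    /// Fetch this module's state value, creating it with `init` if absent.
    /// Panics if a value of the wrong type is stored under `id`.
    fn get_or_init<T: Any + Clone>(&mut self, id: UniqueId, init: impl FnOnce() -> T) -> T {
        self.tensors
            .entry(id)
            .or_insert_with(|| Box::new(init()))
            .downcast_ref::<T>()
            .expect("RnnState entry has unexpected type")
            .clone()
    }
}
```

A cell would call `get_or_init` with a zero-tensor initializer at the start of its forward pass and insert the updated tensor back afterwards.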
self.tanh
    .try_forward(self.l_x.try_forward(x)? + self.l_h.try_forward(h)?)
Suggested change:
- self.tanh
-     .try_forward(self.l_x.try_forward(x)? + self.l_h.try_forward(h)?)
+ (self.l_x.try_forward(x)? + self.l_h.try_forward(h)?).try_tanh()
You can remove self.tanh and other stored activation functions.
pub struct RNN<const I: usize, const O: usize, E: Dtype, D: Device<E>, IS: Dim, OS: Dim> {
    l_x: Linear<I, O, E, D>,
    l_h: Linear<O, O, E, D>,
    tanh: Tanh,
    is: PhantomData<IS>,
    os: PhantomData<OS>,
}
The IS and OS type variables should be removed from this, because they arbitrarily limit how these modules can be used. Instead, I think we should assume that OS is always the same as IS, and implement a separate "SelectLast" operation/module that gets the last element from these sequences.
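A rough sketch of such a module, reusing the `try_select` pattern from the forward pass below; the exact dfdx trait bounds are simplified here and may need adjusting:

```rust
use dfdx::prelude::*;

/// Hypothetical module that picks the last step of a (seq, features) tensor.
struct SelectLast;

impl<Seq: Dim, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
    Module<Tensor<(Seq, Const<O>), E, D, T>> for SelectLast
{
    type Output = Tensor<(Const<O>,), E, D, T>;
    type Error = D::Err;

    fn try_forward(
        &self,
        input: Tensor<(Seq, Const<O>), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        let dev = D::default();
        // Index of the final element along the sequence axis.
        let last = input.shape().0.size() - 1;
        input.try_select(dev.tensor(last))
    }
}
```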
/// input_seq: usize, output_seq: usize, batch: usize
fn try_forward(
    &self,
    input: Tensor<(usize, usize, Const<I>), E, D, T>,
) -> Result<Self::Output, Self::Error> {
    let dev = D::default();
    // (batch, seq, O)
    let mut hsb = vec![];
    let mut h = dev.zeros().retaped::<T>();
    let (input, mut tape) = input.split_tape();
    for b in 0..input.shape().0 {
        // (seq, O)
        let mut hs = vec![];
        let seq = input.retaped::<T>().try_select(dev.tensor(b))?;
        for s in 0..seq.shape().0 {
            let x = seq.retaped::<T>().try_select(dev.tensor(s))?;
            let (h_new, tape_select_cell) = self.cell_try_forward(x, h)?.split_tape();
            tape = tape.merge(tape_select_cell);
            hs.push(h_new.retaped::<T>());
            h = h_new.retaped::<T>();
        }
        hsb.push(hs.try_stack()?);
    }
    let (hsb, tape_stack) = hsb.try_stack()?.split_tape();
    Ok(hsb.put_tape(tape.merge(tape_stack)))
}
}
This and the impl below can be deduplicated by having a type parameter `B: Dim` and using `input.shape().concrete()[0]` to get the batch size.
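A hedged sketch of what the merged impl could look like (bounds and body elided; not verified against dfdx, and the output type may need to preserve `B`):

```rust
// Hypothetical: `B` covers both `usize` (dynamic batch) and `Const<N>`,
// so one impl can replace the two near-duplicate ones.
impl<B: Dim, const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
    Module<Tensor<(B, usize, Const<I>), E, D, T>> for RNN<I, O, E, D>
{
    type Output = Tensor<(usize, usize, Const<O>), E, D, T>;
    type Error = D::Err;

    fn try_forward(
        &self,
        input: Tensor<(B, usize, Const<I>), E, D, T>,
    ) -> Result<Self::Output, Self::Error> {
        // Read the batch size from the runtime shape instead of the type.
        let batch_size = input.shape().concrete()[0];
        // ...same loop as above, iterating over `0..batch_size`...
        todo!()
    }
}
```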
@nkoppel, I'd like to further investigate #204. Would you mind explaining what you meant by this comment?
It has been a few months since I worked on dfdx, but I'll do my best. The issue I was explaining with this comment was that this implementation does not store the hidden state for RNNs in any way, instead passing in the entire sequence at once like a transformer. This is very inefficient for some use cases of RNNs. For example, to run a generative language model, you would have to recompute every previous hidden state before you could generate the next character/token, instead of simply reusing the hidden state you've already computed for the last character/token. My proposal was an effort to resolve this by passing the state of the entire RNN into its input and having each RNN module retrieve and update its hidden state during inference. I reasoned that this would allow a lot of flexibility in how the network is used and would avoid mutating the network during inference.
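To make the inefficiency concrete, here is a self-contained toy illustration (plain Rust, no dfdx; `embed`, `cell_step`, and `readout` are hypothetical stand-ins) of generation that reuses the hidden state, so each new token costs one cell step instead of a re-run over the whole prefix:

```rust
fn main() {
    let mut h = vec![0.0f32; 4]; // hidden state carried across steps
    let mut token = 0usize;
    for _ in 0..10 {
        let x = embed(token);
        h = cell_step(&x, &h); // one step; no recomputation of the prefix
        token = argmax(&readout(&h)); // greedy "sampling" for the toy
    }
    println!("last token: {token}");
}

// Hypothetical stand-ins for the embedding, cell, and readout.
fn embed(token: usize) -> Vec<f32> {
    vec![token as f32; 4]
}
fn cell_step(x: &[f32], h: &[f32]) -> Vec<f32> {
    x.iter().zip(h).map(|(a, b)| (a + b).tanh()).collect()
}
fn readout(h: &[f32]) -> Vec<f32> {
    h.to_vec()
}
fn argmax(v: &[f32]) -> usize {
    v.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
```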
Resolves #204
Usable, but far from an ergonomic design. Many workarounds and hacks are employed.
Ideal design
Low level: cells
On the low level, we can construct several cells like `struct RNNCell<I, O>`, `struct LSTMCell<I, O>`, and `struct GRUCell<I, O>`. All of these cells `impl` tupled forwards (a reconstruction is sketched below).

Not solved: how does LSTM fit into this model (how to batch both $h_{t-1}$ and $c_{t-1}$)?
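A hedged reconstruction of what the tupled forward could look like for the RNN cell, assuming the `l_x`/`l_h` fields from the struct in the review above (not the PR's exact code):

```rust
impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T: Tape<E, D>>
    Module<(Tensor<Rank1<I>, E, D, T>, Tensor<Rank1<O>, E, D, T>)> for RNNCell<I, O, E, D>
{
    type Output = Tensor<Rank1<O>, E, D, T>;
    type Error = D::Err;

    // h_t = tanh(W_x x_t + W_h h_{t-1})
    fn try_forward(
        &self,
        (x, h): (Tensor<Rank1<I>, E, D, T>, Tensor<Rank1<O>, E, D, T>),
    ) -> Result<Self::Output, Self::Error> {
        (self.l_x.try_forward(x)? + self.l_h.try_forward(h)?).try_tanh()
    }
}
```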
The cell-level components allow users to keep hidden states themselves and operate on them freely:
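For example, an illustrative fragment (assuming a built `cell`, a device `dev`, and an `inputs` iterator of `Rank1<16>` tensors):

```rust
// The caller owns and threads the hidden state explicitly.
let mut h: Tensor<Rank1<32>, f32, Cpu> = dev.zeros();
for x in inputs {
    // Tupled forward: (input, previous state) -> new state.
    h = cell.forward((x, h));
}
// `h` can now be stored, inspected, or fed into a later call freely.
```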
High level: recursor
Upon the low-level cells we can build a high-level recursor:
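Presumably something along these lines (a hedged reconstruction):

```rust
use core::marker::PhantomData;
use dfdx::prelude::*;

/// Wraps a cell and drives it over a sequence dimension.
struct Rec<Cell, InSeq: Dim, OutSeq: Dim> {
    cell: Cell,
    seq: PhantomData<(InSeq, OutSeq)>,
}
```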
where `Cell` can be any struct that implements the single and batch forwards above; user-defined structs are also allowed. With this wrapper we can easily integrate with the tuple (sequential) model:
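For instance, with illustrative layer sizes (`RNNCell` as above):

```rust
// Hypothetical sequential model mixing recursors with an ordinary layer:
type Model = (
    Rec<RNNCell<16, 32>, usize, usize>,   // sequence to sequence
    Rec<RNNCell<32, 8>, usize, Const<1>>, // sequence to last element
    Linear<8, 4>,                         // applied to the final state
);
```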
The `InSeq` and `OutSeq` dims are used to control sequence length. The following sequence lengths are implemented:
- `Rec<C, usize, usize>`: sequence to sequence (`(usize, Const<I>)` to `(usize, Const<O>)`)
- `Rec<C, usize, Const<1>>`: sequence to last (`(usize, Const<I>)` to `Rank1<O>`)
- `Rec<C, Const<S>, Const<S>>`: exact sequence to exact sequence (`Rank2<S, I>` to `Rank2<S, O>`)
- `Rec<C, Const<S>, Const<1>>`: exact sequence to last (`Rank2<S, I>` to `Rank1<O>`)
- `Rec<C, Const<1>, Const<S>>`: first to exact sequence (`Rank1<I>` to `Rank2<S, O>`)

There is also a batched version of each of the impls above, e.g. `Rec<C, Const<S>, Const<1>>`: `(usize, Const<S>, Const<I>) -> (usize, Const<O>)` or `Rank3<B, S, I> -> Rank2<B, O>`.
Current workaround
Currently, rust requires const generics in an `impl` to appear either in the `trait` or in the `struct`. So an impl where `O` appears only in the where-clause bounds is not allowed and will give an unconstrained const generic error (E0207) on `O`. So the current workaround does not use a `struct Rec`, but rather derives the `Module` implementation directly on each cell.
Solution
@coreylowman:
- Can we change `trait Module<Input>` to `trait Module<Input, Output>` to overcome rust's limitation on const generics?
- `OwnedTape` is not clonable now (although Clone for GradientTape & OwnedTape #251 says it is), while handmade `Tape` management adds a lot of verbosity and is very error-prone. Can we refactor `BackwardOp` from `Box` to `Arc` to allow cloning of `OwnedTape`?
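For context on the const generics question, a minimal self-contained sketch (simplified hypothetical traits, not the actual dfdx definitions) of the limitation and how an `Output` trait parameter avoids it:

```rust
trait CellForward<const I: usize, const O: usize> {
    fn step(&self, x: [f32; I], h: [f32; O]) -> [f32; O];
}

struct Rec<C>(C);

// With `trait Module<Input>` and an associated `type Output`, `O` would only
// appear in the predicates and the associated type, so rustc rejects it:
// error[E0207]: the const parameter `O` is not constrained by the impl
// trait, self type, or predicates.
//
// With `trait Module<Input, Output>`, `O` appears in the trait ref itself,
// so the equivalent impl is accepted:
trait Module<Input, Output> {
    fn forward(&self, input: Input) -> Output;
}

impl<C, const I: usize, const O: usize> Module<[f32; I], [f32; O]> for Rec<C>
where
    C: CellForward<I, O>,
{
    fn forward(&self, input: [f32; I]) -> [f32; O] {
        self.0.step(input, [0.0; O])
    }
}
```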