Please minimize the requirements for the optimizers #840

Closed
emchristiansen opened this issue Aug 2, 2023 · 1 comment · Fixed by #854
@emchristiansen
Adam and the other optimizers expect the thing they're updating to impl TensorCollection, e.g. in this signature:

impl<M: TensorCollection<E, D>, D: Device<E>, E: Dtype> Optimizer<M, D, E> for Adam<M, E, D> {
    fn update(
        &mut self,
        module: &mut M,
        gradients: &Gradients<E, D>,
    ) -> Result<(), OptimizerUpdateError<D::Err>>;

   ...
}

Here, TensorCollection requires the implementation of iter_tensors:

/// Type alias that specifies how a module's type changes when using a different dtype and/or
/// device.
type To<E2: Dtype, D2: Device<E2>>;

/// Specifies how to iterate through tensors or modules contained within this module, and how
/// to construct this module given values for its fields. Returns `Err(_)` to indicate an error,
/// `Ok(None)` to indicate that there is no error and a module has not been built, and
/// `Ok(Some(_))` contains `Self::Output<E2, D2>`
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
    visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>;

AFAICT this requirement is overkill, as the optimizers don't need to reconstruct the module from its constituent tensors; they just need to mutate the tensors in place.
Also, though not stated here, it appears to assume there are default constructors for the constituent tensors, which is not generally true (see #839).

So, perhaps the optimizers should be refactored to rely on a trait that merely visits and mutates tensors inside existing modules?
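
To make the suggestion concrete, here is a minimal sketch of what such a trait could look like. It does not use dfdx at all: `Tensor`, `VisitTensorsMut`, `Linear`, `Mlp`, and `sgd_step` are all hypothetical stand-ins invented for illustration, and gradients are looked up by a dotted path name purely to keep the example small. The point is only that an update step can be written against a "visit each tensor mutably" trait without ever rebuilding the module or default-constructing a tensor.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a tensor: a flat parameter buffer.
/// Deliberately has no `Default` impl.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

/// Hypothetical trait: hand every tensor in `self` to a callback by mutable
/// reference, together with a dotted path that identifies it.
pub trait VisitTensorsMut {
    fn visit_tensors_mut(&mut self, prefix: &str, f: &mut dyn FnMut(&str, &mut Tensor));
}

pub struct Linear {
    pub weight: Tensor,
    pub bias: Tensor,
}

impl VisitTensorsMut for Linear {
    fn visit_tensors_mut(&mut self, prefix: &str, f: &mut dyn FnMut(&str, &mut Tensor)) {
        f(&format!("{prefix}weight"), &mut self.weight);
        f(&format!("{prefix}bias"), &mut self.bias);
    }
}

pub struct Mlp {
    pub l1: Linear,
    pub l2: Linear,
}

impl VisitTensorsMut for Mlp {
    fn visit_tensors_mut(&mut self, prefix: &str, f: &mut dyn FnMut(&str, &mut Tensor)) {
        // Nested modules just forward the visitor with a longer prefix.
        self.l1.visit_tensors_mut(&format!("{prefix}l1."), f);
        self.l2.visit_tensors_mut(&format!("{prefix}l2."), f);
    }
}

/// Plain SGD step written only against `VisitTensorsMut`.
/// Tensors without a gradient entry are left untouched.
/// Returns the number of tensors that were updated.
pub fn sgd_step<M: VisitTensorsMut>(
    module: &mut M,
    grads: &HashMap<String, Vec<f32>>,
    lr: f32,
) -> usize {
    let mut updated = 0;
    module.visit_tensors_mut("", &mut |name: &str, t: &mut Tensor| {
        if let Some(grad) = grads.get(name) {
            for (p, g) in t.data.iter_mut().zip(grad) {
                *p -= lr * g;
            }
            updated += 1;
        }
    });
    updated
}
```

With something like this, an optimizer would only need `M: VisitTensorsMut`, and a module whose tensors cannot be default-constructed could still be optimized. A real version would presumably key gradients by tensor id rather than by string path, and would need to be generic over dtype and device.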

@coreylowman
Owner

Yep, this will be addressed in the nn rewrite.
