From c6da1a1709f07c392a4b0462bb8857c5619647f9 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 17:48:04 -0500 Subject: [PATCH 01/10] add infrastructure for hyperparameter support for TensorCollections; add Gradients::as_model --- src/nn/tensor_collection/collection.rs | 51 +++++++++++++++++++++- src/nn/tensor_collection/mod.rs | 6 +-- src/nn/tensor_collection/visitor.rs | 26 +++++++++++- src/nn/tensor_collection/visitor_impls.rs | 52 ++++++++++++++++++++--- src/shapes/shape.rs | 1 + src/tensor/numpy.rs | 2 +- 6 files changed, 126 insertions(+), 12 deletions(-) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index fafe51f8e..bb8ec6def 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -6,7 +6,7 @@ use crate::{ tensor_ops::Device, }; -use super::{ModuleField, ModuleFields, TensorField}; +use super::{ModuleField, ModuleFields, TensorField, HyperparameterField}; /// A collection of named tensors. Implementing this trait will enable anything /// that operates on tensors, including resetting, counting number of params, updating gradients, @@ -105,6 +105,29 @@ pub trait TensorCollection>: Sized { m: Default::default(), } } + + /// Creates a [ModuleFields] that represents hyperparamter tensor field. + /// + /// See also: [TensorField], [TensorCollection], [TensorOptions]. + fn hyperparameter( + name: &str, + get_ref: F1, + get_mut: F2, + options: HyperparameterOptions, + ) -> HyperparameterField + where + F1: FnMut(&Self) -> &N, + F2: FnMut(&mut Self) -> &mut N, + N: num_traits::ToPrimitive, + { + HyperparameterField { + name, + get_ref, + get_mut, + options, + m: Default::default(), + } + } } /// An object that can visit [TensorCollection]s and [Tensor]s recursively. @@ -137,6 +160,18 @@ pub trait ModuleVisitor, E: Dtype, D: Device>: Size GetRef: FnMut(&T) -> &Tensor, GetMut: FnMut(&mut T) -> &mut Tensor; + fn visit_hyperparameter( + &mut self, + name: &str, + get_refs: GetRef, + get_muts: GetMut, + opts: HyperparameterOptions, + ) -> Result, Self::Err> + where + N: num_traits::ToPrimitive, + GetRef: FnMut(&T) -> &N, + GetMut: FnMut(&mut T) -> &mut N; + /// Takes something that implements [ModuleFields] and function that takes /// [ModuleFields::Output] and returns an instance of T. fn visit_fields>( @@ -229,3 +264,17 @@ impl> TensorOptions { } } } + +/// Options to change behavior of [ModuleVisitor] +#[non_exhaustive] +pub struct HyperparameterOptions { + /// The default value for this parameter + pub default: N, +} + +impl HyperparameterOptions { + // Constructs a HyperparameterOptions using the parameter's default value + pub fn from_default(default: N) -> Self { + Self { default } + } +} diff --git a/src/nn/tensor_collection/mod.rs b/src/nn/tensor_collection/mod.rs index dbfe10459..657265347 100644 --- a/src/nn/tensor_collection/mod.rs +++ b/src/nn/tensor_collection/mod.rs @@ -6,8 +6,8 @@ mod collection; mod visitor; mod visitor_impls; -pub use collection::{ModuleVisitor, TensorCollection, TensorOptions}; +pub use collection::{HyperparameterOptions, ModuleVisitor, TensorCollection, TensorOptions}; pub use visitor::{ - ModuleField, ModuleFields, RecursiveWalker, TensorField, TensorViewer, TensorVisitor, - ViewTensorMut, ViewTensorName, ViewTensorRef, + HyperparameterField, ModuleField, ModuleFields, RecursiveWalker, TensorField, TensorViewer, + TensorVisitor, ViewTensorMut, ViewTensorName, ViewTensorRef, }; diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 289eff6a7..62d99b23d 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -4,7 +4,7 @@ use crate::{ tensor_ops::Device, }; -use super::{ModuleVisitor, TensorCollection, TensorOptions}; +use super::{HyperparameterOptions, ModuleVisitor, TensorCollection, TensorOptions}; /// A standard [ModuleVisitor] that executes `F` on every [Tensor] encountered. /// `F` must implement [TensorVisitor] @@ -77,6 +77,14 @@ pub trait TensorVisitor> { opts: TensorOptions, t: ::View<'_, Tensor>, ) -> Result>, Self::Err>; + + fn visit_hyperparameter( + &mut self, + opts: HyperparameterOptions, + _h: ::View<'_, N>, + ) -> Result, Self::Err> { + Ok(Some(opts.default)) + } } /// Something that can view [Tensor]s in different ways. For example @@ -110,7 +118,7 @@ pub trait ModuleFields, E: Dtype, D: Device> { /// and returns optionally constructed fields fn visit_fields>( self, - module: &mut V, + visitor: &mut V, ) -> Result, V::Err>; /// If any optional fields are None, returns None. Otherwise returns instances of all fields. @@ -145,6 +153,20 @@ where pub(super) m: std::marker::PhantomData, } +/// A [ModuleFields] that represents a field that contains single number which should be serialized. +pub struct HyperparameterField<'a, F1, F2, Mod, N> +where + N: num_traits::ToPrimitive, + F1: FnMut(&Mod) -> &N, + F2: FnMut(&mut Mod) -> &mut N, +{ + pub(super) name: &'a str, + pub(super) get_ref: F1, + pub(super) get_mut: F2, + pub(super) options: HyperparameterOptions, + pub(super) m: std::marker::PhantomData, +} + /// A [TensorViewer] that represents a `&Tensor` #[derive(Debug)] pub enum ViewTensorRef {} diff --git a/src/nn/tensor_collection/visitor_impls.rs b/src/nn/tensor_collection/visitor_impls.rs index 48612a7f0..1f0a03a03 100644 --- a/src/nn/tensor_collection/visitor_impls.rs +++ b/src/nn/tensor_collection/visitor_impls.rs @@ -53,6 +53,24 @@ impl<'a, T: TensorCollection, E: Dtype, D: Device, F: TensorVisitor( + &mut self, + name: &str, + mut get_refs: GetRef, + mut get_muts: GetMut, + opts: HyperparameterOptions, + ) -> Result, Self::Err> + where + N: num_traits::ToPrimitive, + GetRef: FnMut(&T) -> &N, + GetMut: FnMut(&mut T) -> &mut N, + { + self.f.visit_hyperparameter( + opts, + F::Viewer::view_field(&mut self.m, name, &mut get_refs, &mut get_muts), + ) + } + fn visit_fields>( &mut self, fields: M, @@ -219,6 +237,30 @@ where } } +impl<'a, F1, F2, Mod, N, E: Dtype, D: Device> ModuleFields for HyperparameterField<'a, F1, F2, Mod, N> +where + N: num_traits::ToPrimitive, + F1: FnMut(&Mod) -> &N, + F2: FnMut(&mut Mod) -> &mut N, + Mod: TensorCollection, +{ + type Options> = Option; + type Output> = N; + + fn visit_fields>( + self, + visitor: &mut V, + ) -> Result, V::Err> { + visitor.visit_hyperparameter(self.name, self.get_ref, self.get_mut, self.options) + } + + fn handle_options>( + options: Self::Options, + ) -> Option> { + options + } +} + impl, Mod: TensorCollection, E: Dtype, D: Device> ModuleFields for Vec { @@ -227,12 +269,12 @@ impl, Mod: TensorCollection, E: Dtype, D: Devic fn visit_fields>( self, - module: &mut V, + visitor: &mut V, ) -> Result, V::Err> { let mut out = Vec::with_capacity(self.len()); for x in self { - out.push(x.visit_fields(module)?); + out.push(x.visit_fields(visitor)?); } Ok(out) @@ -255,7 +297,7 @@ impl, E: Dtype, D: Device> ModuleFields> = (); type Output> = (); - fn visit_fields>(self, _module: &mut V) -> Result<(), V::Err> { + fn visit_fields>(self, _visitor: &mut V) -> Result<(), V::Err> { Ok(()) } @@ -270,7 +312,7 @@ macro_rules! tuple_impls { type View<'a, Mod: 'a> = ($($name::View<'a, Mod>,)+); fn view_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::View<'_, Mod>, + visitor: &'a mut Self::View<'_, Mod>, name: &str, get_ref: &mut GetRef, get_mut: &mut GetMut, @@ -279,7 +321,7 @@ macro_rules! tuple_impls { GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { - ($($name::view_field(&mut module.$idx, name, get_ref, get_mut),)+) + ($($name::view_field(&mut visitor.$idx, name, get_ref, get_mut),)+) } } diff --git a/src/shapes/shape.rs b/src/shapes/shape.rs index ec16dbc9e..a0130726b 100644 --- a/src/shapes/shape.rs +++ b/src/shapes/shape.rs @@ -70,6 +70,7 @@ pub trait Dtype: + std::ops::MulAssign + std::ops::DivAssign + num_traits::FromPrimitive + + num_traits::ToPrimitive { } impl Dtype for f32 {} diff --git a/src/tensor/numpy.rs b/src/tensor/numpy.rs index 68f32f0c8..d9c245c21 100644 --- a/src/tensor/numpy.rs +++ b/src/tensor/numpy.rs @@ -41,7 +41,7 @@ impl, T> Tensor { } let mut f = r .by_name(&filename) - .expect(&std::format!("'{}' not found", filename)); + .unwrap_or_else(|_| panic!("'{filename}' not found")); self.read_from(&mut f)?; Ok(()) } From 72f8b1a198fd790bc09e5a82552875b293aa4a73 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 17:59:42 -0500 Subject: [PATCH 02/10] run cargo fmt --- src/nn/tensor_collection/collection.rs | 2 +- src/nn/tensor_collection/visitor_impls.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index bb8ec6def..c256b003c 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -6,7 +6,7 @@ use crate::{ tensor_ops::Device, }; -use super::{ModuleField, ModuleFields, TensorField, HyperparameterField}; +use super::{HyperparameterField, ModuleField, ModuleFields, TensorField}; /// A collection of named tensors. Implementing this trait will enable anything /// that operates on tensors, including resetting, counting number of params, updating gradients, diff --git a/src/nn/tensor_collection/visitor_impls.rs b/src/nn/tensor_collection/visitor_impls.rs index 1f0a03a03..36857e075 100644 --- a/src/nn/tensor_collection/visitor_impls.rs +++ b/src/nn/tensor_collection/visitor_impls.rs @@ -237,7 +237,8 @@ where } } -impl<'a, F1, F2, Mod, N, E: Dtype, D: Device> ModuleFields for HyperparameterField<'a, F1, F2, Mod, N> +impl<'a, F1, F2, Mod, N, E: Dtype, D: Device> ModuleFields + for HyperparameterField<'a, F1, F2, Mod, N> where N: num_traits::ToPrimitive, F1: FnMut(&Mod) -> &N, From cd58811716f359cb249679c6175622d88460c21f Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 18:12:21 -0500 Subject: [PATCH 03/10] use 'scalar' instead of 'hyperparameter'; use NumCast instead of ToPrimitive --- src/nn/tensor_collection/collection.rs | 25 ++++++++++++----------- src/nn/tensor_collection/mod.rs | 4 ++-- src/nn/tensor_collection/visitor.rs | 14 +++++++------ src/nn/tensor_collection/visitor_impls.rs | 16 ++++++++------- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index c256b003c..58df9ee3d 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -1,4 +1,5 @@ #![allow(clippy::type_complexity)] +use num_traits::NumCast; use crate::{ shapes::{ConstShape, Dtype, Shape}, @@ -6,7 +7,7 @@ use crate::{ tensor_ops::Device, }; -use super::{HyperparameterField, ModuleField, ModuleFields, TensorField}; +use super::{ScalarField, ModuleField, ModuleFields, TensorField}; /// A collection of named tensors. Implementing this trait will enable anything /// that operates on tensors, including resetting, counting number of params, updating gradients, @@ -109,18 +110,18 @@ pub trait TensorCollection>: Sized { /// Creates a [ModuleFields] that represents hyperparamter tensor field. /// /// See also: [TensorField], [TensorCollection], [TensorOptions]. - fn hyperparameter( + fn scalar( name: &str, get_ref: F1, get_mut: F2, - options: HyperparameterOptions, - ) -> HyperparameterField + options: ScalarOptions, + ) -> ScalarField where F1: FnMut(&Self) -> &N, F2: FnMut(&mut Self) -> &mut N, - N: num_traits::ToPrimitive, + N: NumCast, { - HyperparameterField { + ScalarField { name, get_ref, get_mut, @@ -160,15 +161,15 @@ pub trait ModuleVisitor, E: Dtype, D: Device>: Size GetRef: FnMut(&T) -> &Tensor, GetMut: FnMut(&mut T) -> &mut Tensor; - fn visit_hyperparameter( + fn visit_scalar( &mut self, name: &str, get_refs: GetRef, get_muts: GetMut, - opts: HyperparameterOptions, + opts: ScalarOptions, ) -> Result, Self::Err> where - N: num_traits::ToPrimitive, + N: NumCast, GetRef: FnMut(&T) -> &N, GetMut: FnMut(&mut T) -> &mut N; @@ -267,13 +268,13 @@ impl> TensorOptions { /// Options to change behavior of [ModuleVisitor] #[non_exhaustive] -pub struct HyperparameterOptions { +pub struct ScalarOptions { /// The default value for this parameter pub default: N, } -impl HyperparameterOptions { - // Constructs a HyperparameterOptions using the parameter's default value +impl ScalarOptions { + // Constructs a ScalarOptions using the parameter's default value pub fn from_default(default: N) -> Self { Self { default } } diff --git a/src/nn/tensor_collection/mod.rs b/src/nn/tensor_collection/mod.rs index 657265347..ba4edad4a 100644 --- a/src/nn/tensor_collection/mod.rs +++ b/src/nn/tensor_collection/mod.rs @@ -6,8 +6,8 @@ mod collection; mod visitor; mod visitor_impls; -pub use collection::{HyperparameterOptions, ModuleVisitor, TensorCollection, TensorOptions}; +pub use collection::{ScalarOptions, ModuleVisitor, TensorCollection, TensorOptions}; pub use visitor::{ - HyperparameterField, ModuleField, ModuleFields, RecursiveWalker, TensorField, TensorViewer, + ScalarField, ModuleField, ModuleFields, RecursiveWalker, TensorField, TensorViewer, TensorVisitor, ViewTensorMut, ViewTensorName, ViewTensorRef, }; diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 62d99b23d..6c85bc55c 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -1,10 +1,12 @@ +use num_traits::NumCast; + use crate::{ shapes::{Dtype, Shape}, tensor::Tensor, tensor_ops::Device, }; -use super::{HyperparameterOptions, ModuleVisitor, TensorCollection, TensorOptions}; +use super::{ScalarOptions, ModuleVisitor, TensorCollection, TensorOptions}; /// A standard [ModuleVisitor] that executes `F` on every [Tensor] encountered. /// `F` must implement [TensorVisitor] @@ -78,9 +80,9 @@ pub trait TensorVisitor> { t: ::View<'_, Tensor>, ) -> Result>, Self::Err>; - fn visit_hyperparameter( + fn visit_scalar( &mut self, - opts: HyperparameterOptions, + opts: ScalarOptions, _h: ::View<'_, N>, ) -> Result, Self::Err> { Ok(Some(opts.default)) @@ -154,16 +156,16 @@ where } /// A [ModuleFields] that represents a field that contains single number which should be serialized. -pub struct HyperparameterField<'a, F1, F2, Mod, N> +pub struct ScalarField<'a, F1, F2, Mod, N> where - N: num_traits::ToPrimitive, + N: NumCast, F1: FnMut(&Mod) -> &N, F2: FnMut(&mut Mod) -> &mut N, { pub(super) name: &'a str, pub(super) get_ref: F1, pub(super) get_mut: F2, - pub(super) options: HyperparameterOptions, + pub(super) options: ScalarOptions, pub(super) m: std::marker::PhantomData, } diff --git a/src/nn/tensor_collection/visitor_impls.rs b/src/nn/tensor_collection/visitor_impls.rs index 36857e075..4fcb47496 100644 --- a/src/nn/tensor_collection/visitor_impls.rs +++ b/src/nn/tensor_collection/visitor_impls.rs @@ -1,3 +1,5 @@ +use num_traits::NumCast; + use std::{ string::{String, ToString}, vec::Vec, @@ -53,19 +55,19 @@ impl<'a, T: TensorCollection, E: Dtype, D: Device, F: TensorVisitor( + fn visit_scalar( &mut self, name: &str, mut get_refs: GetRef, mut get_muts: GetMut, - opts: HyperparameterOptions, + opts: ScalarOptions, ) -> Result, Self::Err> where - N: num_traits::ToPrimitive, + N: NumCast, GetRef: FnMut(&T) -> &N, GetMut: FnMut(&mut T) -> &mut N, { - self.f.visit_hyperparameter( + self.f.visit_scalar( opts, F::Viewer::view_field(&mut self.m, name, &mut get_refs, &mut get_muts), ) @@ -238,9 +240,9 @@ where } impl<'a, F1, F2, Mod, N, E: Dtype, D: Device> ModuleFields - for HyperparameterField<'a, F1, F2, Mod, N> + for ScalarField<'a, F1, F2, Mod, N> where - N: num_traits::ToPrimitive, + N: NumCast, F1: FnMut(&Mod) -> &N, F2: FnMut(&mut Mod) -> &mut N, Mod: TensorCollection, @@ -252,7 +254,7 @@ where self, visitor: &mut V, ) -> Result, V::Err> { - visitor.visit_hyperparameter(self.name, self.get_ref, self.get_mut, self.options) + visitor.visit_scalar(self.name, self.get_ref, self.get_mut, self.options) } fn handle_options>( From 07f63cc81fd02e8375f0aa95cc285a8fae3b5ce6 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 18:13:24 -0500 Subject: [PATCH 04/10] run cargo fmt --- src/nn/tensor_collection/collection.rs | 2 +- src/nn/tensor_collection/mod.rs | 4 ++-- src/nn/tensor_collection/visitor.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index 58df9ee3d..aa08b454d 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -7,7 +7,7 @@ use crate::{ tensor_ops::Device, }; -use super::{ScalarField, ModuleField, ModuleFields, TensorField}; +use super::{ModuleField, ModuleFields, ScalarField, TensorField}; /// A collection of named tensors. Implementing this trait will enable anything /// that operates on tensors, including resetting, counting number of params, updating gradients, diff --git a/src/nn/tensor_collection/mod.rs b/src/nn/tensor_collection/mod.rs index ba4edad4a..b316ea9c4 100644 --- a/src/nn/tensor_collection/mod.rs +++ b/src/nn/tensor_collection/mod.rs @@ -6,8 +6,8 @@ mod collection; mod visitor; mod visitor_impls; -pub use collection::{ScalarOptions, ModuleVisitor, TensorCollection, TensorOptions}; +pub use collection::{ModuleVisitor, ScalarOptions, TensorCollection, TensorOptions}; pub use visitor::{ - ScalarField, ModuleField, ModuleFields, RecursiveWalker, TensorField, TensorViewer, + ModuleField, ModuleFields, RecursiveWalker, ScalarField, TensorField, TensorViewer, TensorVisitor, ViewTensorMut, ViewTensorName, ViewTensorRef, }; diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 6c85bc55c..466d17ec8 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -6,7 +6,7 @@ use crate::{ tensor_ops::Device, }; -use super::{ScalarOptions, ModuleVisitor, TensorCollection, TensorOptions}; +use super::{ModuleVisitor, ScalarOptions, TensorCollection, TensorOptions}; /// A standard [ModuleVisitor] that executes `F` on every [Tensor] encountered. /// `F` must implement [TensorVisitor] From bf8396689ccbf944abb6543933accfa31b52fd15 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 18:32:54 -0500 Subject: [PATCH 05/10] add scalars to TensorCollection impls for examples listed in #485 --- src/nn/batchnorm1d.rs | 18 +++++++++++++++--- src/nn/batchnorm2d.rs | 18 +++++++++++++++--- src/nn/dropout.rs | 22 +++++++++++++++++++++- src/nn/layer_norm.rs | 10 ++++++++-- 4 files changed, 59 insertions(+), 9 deletions(-) diff --git a/src/nn/batchnorm1d.rs b/src/nn/batchnorm1d.rs index 848fed421..df0d9fb54 100644 --- a/src/nn/batchnorm1d.rs +++ b/src/nn/batchnorm1d.rs @@ -199,14 +199,26 @@ impl> TensorCollection for BatchNor |s| &mut s.running_var, TensorOptions::detached(|t| t.try_fill_with_ones()), ), + Self::scalar( + "epsilon", + |s| &s.epsilon, + |s| &mut s.epsilon, + ScalarOptions::from_default(1e-5), + ), + Self::scalar( + "momentum", + |s| &s.momentum, + |s| &mut s.momentum, + ScalarOptions::from_default(0.1), + ), ), - |(scale, bias, running_mean, running_var)| BatchNorm1D { + |(scale, bias, running_mean, running_var, epsilon, momentum)| BatchNorm1D { scale, bias, running_mean, running_var, - epsilon: 1e-5, - momentum: 0.1, + epsilon, + momentum, }, ) } diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index 4f74b701a..400da5590 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -272,14 +272,26 @@ impl> TensorCollection for BatchNor |s| &mut s.running_var, TensorOptions::detached(|t| t.try_fill_with_ones()), ), + Self::scalar( + "epsilon", + |s| &s.epsilon, + |s| &mut s.epsilon, + ScalarOptions::from_default(1e-5), + ), + Self::scalar( + "momentum", + |s| &s.momentum, + |s| &mut s.momentum, + ScalarOptions::from_default(0.1), + ), ), - |(scale, bias, running_mean, running_var)| BatchNorm2D { + |(scale, bias, running_mean, running_var, epsilon, momentum)| BatchNorm2D { scale, bias, running_mean, running_var, - epsilon: 1e-5, - momentum: 0.1, + epsilon, + momentum, }, ) } diff --git a/src/nn/dropout.rs b/src/nn/dropout.rs index 018c36d35..329221cf5 100644 --- a/src/nn/dropout.rs +++ b/src/nn/dropout.rs @@ -129,7 +129,27 @@ impl Default for Dropout { } } -impl ZeroSizedModule for Dropout {} +impl, E: Dtype> BuildOnDevice for Dropout { + type Built = Dropout; +} + +impl> TensorCollection for Dropout { + type To> = Dropout; + + fn iter_tensors>( + visitor: &mut V, + ) -> Result>, V::Err> { + visitor.visit_fields( + >::scalar( + "p", + |s| &s.p, + |s| &mut s.p, + ScalarOptions::from_default(0.5), + ), + |p| Dropout { p } + ) + } +} impl> Module> for Dropout { type Output = Tensor; diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index e867b8704..5d197795e 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -64,11 +64,17 @@ impl> TensorCollection for LayerNor |s| &mut s.beta, TensorOptions::reset_to_zeros(), ), + Self::scalar( + "epsilon", + |s| &s.epsilon, + |s| &mut s.epsilon, + ScalarOptions::from_default(1e-5), + ), ), - |(gamma, beta)| LayerNorm1D { + |(gamma, beta, epsilon)| LayerNorm1D { gamma, beta, - epsilon: 1e-5, + epsilon, }, ) } From d2dd8ba0632498aea95ee5ed4ac45235f67352ba Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 18:33:18 -0500 Subject: [PATCH 06/10] run cargo fmt --- src/nn/dropout.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nn/dropout.rs b/src/nn/dropout.rs index 329221cf5..7ff12553b 100644 --- a/src/nn/dropout.rs +++ b/src/nn/dropout.rs @@ -146,7 +146,7 @@ impl> TensorCollection for Dropout { |s| &mut s.p, ScalarOptions::from_default(0.5), ), - |p| Dropout { p } + |p| Dropout { p }, ) } } From 1b51fabcc39e92992fa7fd776e33da21861e3d59 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 20:30:10 -0500 Subject: [PATCH 07/10] implement npz scalar serialization --- src/nn/npz.rs | 30 ++++++++++++++- src/tensor/numpy.rs | 89 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 97 insertions(+), 22 deletions(-) diff --git a/src/nn/npz.rs b/src/nn/npz.rs index cc57e63dd..3ce1d207a 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -1,7 +1,7 @@ use crate::{ shapes::{Dtype, Shape}, tensor::{ - numpy::{NpzError, NumpyDtype}, + numpy::{read_from_npz, write_to_npz, NpzError, NumpyDtype}, Tensor, }, tensor_ops::Device, @@ -129,6 +129,18 @@ impl> TensorVisitor t.write_to_npz(self, full_path)?; Ok(None) } + + fn visit_scalar( + &mut self, + _opts: ScalarOptions, + (n, full_path): (&N, String), + ) -> Result, Self::Err> { + let n = n + .to_f64() + .unwrap_or_else(|| panic!("Failed to convert scalar value at {full_path} to f64!")); + write_to_npz(self, &[], &[n], full_path)?; + Ok(None) + } } impl> TensorVisitor @@ -147,6 +159,22 @@ impl> TensorVisitor t.read_from_npz(self, full_path)?; Ok(None) } + + fn visit_scalar( + &mut self, + _opts: ScalarOptions, + (n, full_path): (&mut N, String), + ) -> Result, Self::Err> { + let buf: Vec = read_from_npz(self, &[], full_path)?; + *n = N::from(buf[0]).unwrap_or_else(|| { + panic!( + "Failed to convert f64 value {} to {}!", + buf[0], + std::any::type_name::() + ) + }); + Ok(None) + } } #[cfg(test)] diff --git a/src/tensor/numpy.rs b/src/tensor/numpy.rs index d9c245c21..a213da591 100644 --- a/src/tensor/numpy.rs +++ b/src/tensor/numpy.rs @@ -58,34 +58,81 @@ impl, T> Tensor { self.write_to(&mut f) } - pub(crate) fn read_from(&mut self, r: &mut R) -> Result<(), NpyError> { - let endian = read_header::(r, self.shape().concrete().into_iter().collect())?; - let numel = self.shape().num_elements(); - let mut buf = Vec::with_capacity(numel); - for _ in 0..numel { - buf.push(E::read_endian(r, endian)?); - } - D::copy_from(self, &buf); + fn read_from(&mut self, r: &mut R) -> Result<(), NpyError> { + let buf = read_from_npy(r, self.shape().concrete().as_ref())?; + self.copy_from(&buf); Ok(()) } - pub(crate) fn write_to(&self, w: &mut W) -> io::Result<()> { - let endian = Endian::Little; - write_header::(w, endian, self.shape().concrete().into_iter().collect())?; - let numel = self.shape().num_elements(); - let mut buf = std::vec![Default::default(); numel]; - D::copy_into(self, &mut buf); - for v in buf.iter() { - v.write_endian(w, endian)?; - } - Ok(()) + fn write_to(&self, w: &mut W) -> io::Result<()> { + let buf = self.as_vec(); + write_to_npy(w, self.shape().concrete().as_ref(), &buf) + } +} + +pub(crate) fn write_to_npz( + w: &mut zip::ZipWriter, + shape: &[usize], + data: &[E], + mut filename: String, +) -> io::Result<()> { + if !filename.ends_with(".npy") { + filename.push_str(".npy"); } + w.start_file(filename, Default::default())?; + write_to_npy(w, shape, data) +} + +pub(crate) fn read_from_npz( + r: &mut zip::ZipArchive, + shape: &[usize], + mut filename: String, +) -> Result, NpyError> { + if !filename.ends_with(".npy") { + filename.push_str(".npy"); + } + + let mut f = r + .by_name(&filename) + .unwrap_or_else(|_| panic!("'{filename}' not found")); + + read_from_npy(&mut f, shape) +} + +pub(crate) fn write_to_npy( + w: &mut W, + shape: &[usize], + data: &[E], +) -> io::Result<()> { + let endian = Endian::Little; + write_header::(w, endian, shape)?; + + for v in data.iter() { + v.write_endian(w, endian)?; + } + + Ok(()) +} + +pub(crate) fn read_from_npy( + r: &mut R, + shape: &[usize], +) -> Result, NpyError> { + let endian = read_header::(r, shape)?; + let numel = shape.iter().product::(); + let mut out = Vec::new(); + + for _ in 0..numel { + out.push(E::read_endian(r, endian)?); + } + + Ok(out) } fn write_header( w: &mut W, endian: Endian, - shape: Vec, + shape: &[usize], ) -> io::Result<()> { let shape_str = to_shape_str(shape); @@ -121,7 +168,7 @@ fn write_header( Ok(()) } -fn read_header(r: &mut R, shape: Vec) -> Result { +fn read_header(r: &mut R, shape: &[usize]) -> Result { let mut magic = [0; 6]; r.read_exact(&mut magic)?; if magic != MAGIC_NUMBER { @@ -311,7 +358,7 @@ impl From for NpyError { } } -fn to_shape_str(shape: Vec) -> String { +fn to_shape_str(shape: &[usize]) -> String { shape .iter() .map(|v| v.to_string()) From 6f0092a6a72050070fa7519ca91a39bb8246221d Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 20:58:27 -0500 Subject: [PATCH 08/10] fix documentation; simplify tensor/numpy.rs --- src/nn/npz.rs | 2 +- src/nn/tensor_collection/collection.rs | 2 +- src/nn/tensor_collection/visitor.rs | 2 +- src/tensor/numpy.rs | 36 +++++++------------------- 4 files changed, 13 insertions(+), 29 deletions(-) diff --git a/src/nn/npz.rs b/src/nn/npz.rs index 3ce1d207a..ae0cb574c 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -168,7 +168,7 @@ impl> TensorVisitor let buf: Vec = read_from_npz(self, &[], full_path)?; *n = N::from(buf[0]).unwrap_or_else(|| { panic!( - "Failed to convert f64 value {} to {}!", + "Failed to convert f64 value {} to {} when reading from npz!", buf[0], std::any::type_name::() ) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index aa08b454d..8617ffeff 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -107,7 +107,7 @@ pub trait TensorCollection>: Sized { } } - /// Creates a [ModuleFields] that represents hyperparamter tensor field. + /// Creates a [ModuleFields] that represents a scalar field. /// /// See also: [TensorField], [TensorCollection], [TensorOptions]. fn scalar( diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 466d17ec8..c8bdbd0e2 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -155,7 +155,7 @@ where pub(super) m: std::marker::PhantomData, } -/// A [ModuleFields] that represents a field that contains single number which should be serialized. +/// A [ModuleFields] that represents a field that contains a scalar value that should be serialized. pub struct ScalarField<'a, F1, F2, Mod, N> where N: NumCast, diff --git a/src/tensor/numpy.rs b/src/tensor/numpy.rs index a213da591..9d6ac77cb 100644 --- a/src/tensor/numpy.rs +++ b/src/tensor/numpy.rs @@ -20,13 +20,10 @@ impl, T> Tensor { pub fn write_to_npz( &self, w: &mut zip::ZipWriter, - mut filename: String, + filename: String, ) -> ZipResult<()> { - if !filename.ends_with(".npy") { - filename.push_str(".npy"); - } - w.start_file(filename, Default::default())?; - self.write_to(w)?; + let buf = self.as_vec(); + write_to_npz(w, self.shape().concrete().as_ref(), &buf, filename)?; Ok(()) } @@ -34,39 +31,26 @@ impl, T> Tensor { pub fn read_from_npz( &mut self, r: &mut zip::ZipArchive, - mut filename: String, + filename: String, ) -> Result<(), NpzError> { - if !filename.ends_with(".npy") { - filename.push_str(".npy"); - } - let mut f = r - .by_name(&filename) - .unwrap_or_else(|_| panic!("'{filename}' not found")); - self.read_from(&mut f)?; + let buf = read_from_npz(r, self.shape().concrete().as_ref(), filename)?; + self.copy_from(&buf); Ok(()) } /// Attemps to load the data from a `.npy` file at `path` pub fn load_from_npy>(&mut self, path: P) -> Result<(), NpyError> { let mut f = BufReader::new(File::open(path)?); - self.read_from(&mut f) + let buf = read_from_npy(&mut f, self.shape().concrete().as_ref())?; + self.copy_from(&buf); + Ok(()) } /// Saves the tensor to a `.npy` file located at `path` pub fn save_to_npy>(&self, path: P) -> io::Result<()> { let mut f = BufWriter::new(File::create(path)?); - self.write_to(&mut f) - } - - fn read_from(&mut self, r: &mut R) -> Result<(), NpyError> { - let buf = read_from_npy(r, self.shape().concrete().as_ref())?; - self.copy_from(&buf); - Ok(()) - } - - fn write_to(&self, w: &mut W) -> io::Result<()> { let buf = self.as_vec(); - write_to_npy(w, self.shape().concrete().as_ref(), &buf) + write_to_npy(&mut f, self.shape().concrete().as_ref(), &buf) } } From c701cace346f8aa4eca628a369fbe63302ec40c0 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 28 Jun 2023 23:55:54 -0500 Subject: [PATCH 09/10] add scalars to safetensor serialization --- src/nn/safetensors.rs | 48 ++++++++++++++++++++++++++++++++++++------- src/tensor/numpy.rs | 4 ++-- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/nn/safetensors.rs b/src/nn/safetensors.rs index f6acf02f7..c552369cd 100644 --- a/src/nn/safetensors.rs +++ b/src/nn/safetensors.rs @@ -12,7 +12,6 @@ use safetensors::{ tensor::{Dtype as SDtype, SafeTensors, TensorView}, SafeTensorError, }; -use std::collections::BTreeMap; use super::tensor_collection::*; @@ -25,12 +24,12 @@ struct TensorData { } pub struct Writer { - tensors: BTreeMap, + tensors: Vec<(String, TensorData)>, } impl Writer { pub fn new() -> Self { - let tensors = BTreeMap::new(); + let tensors = Vec::new(); Self { tensors } } @@ -44,11 +43,11 @@ impl Writer { let data = tensor.as_vec(); let data: Vec = data.iter().flat_map(|f| f.to_le_bytes()).collect(); let tdata = TensorData { dtype, shape, data }; - self.tensors.insert(key, tdata); + self.tensors.push((key, tdata)); } pub fn save(&self, path: &Path) -> Result<(), SafeTensorError> { - let views: BTreeMap = self + let (names, views): (Vec, Vec) = self .tensors .iter() .map(|(k, tensor)| { @@ -57,8 +56,11 @@ impl Writer { TensorView::new(tensor.dtype, tensor.shape.clone(), &tensor.data).unwrap(), ) }) - .collect(); - serialize_to_file(&views, &None, path) + .unzip(); + + let data = names.into_iter().zip(views.iter()); + + serialize_to_file(data, &None, path) } } @@ -76,6 +78,20 @@ impl> TensorVisitor for Writer { self.add(full_path, t); Ok(None) } + + fn visit_scalar( + &mut self, + _: ScalarOptions, + (n, full_path): (&N, String), + ) -> Result, Self::Err> { + let data = TensorData { + dtype: safetensors::Dtype::F64, + shape: Vec::new(), + data: n.to_f64().unwrap_or_else(|| panic!("Failed to convert scalar value at {full_path} to f64!")).to_le_bytes().to_vec(), + }; + self.tensors.push((full_path, data)); + Ok(None) + } } /// Something that can be saved to a `.safetensors`. @@ -148,6 +164,24 @@ impl<'data, E: Dtype + SafeDtype, D: Device> TensorVisitor for SafeTens t.load_safetensor(self, &full_path)?; Ok(None) } + + fn visit_scalar( + &mut self, + _: ScalarOptions, + (n, full_path): (&mut N, String), + ) -> Result, Self::Err> { + let data = self.tensor(&full_path)?.data(); + let mut array = [0; 8]; + array.copy_from_slice(data); + let val = f64::from_le_bytes(array); + *n = N::from(val).unwrap_or_else(|| { + panic!( + "Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!", + std::any::type_name::() + ) + }); + Ok(None) + } } #[cfg(test)] diff --git a/src/tensor/numpy.rs b/src/tensor/numpy.rs index 9d6ac77cb..a75404f3d 100644 --- a/src/tensor/numpy.rs +++ b/src/tensor/numpy.rs @@ -83,7 +83,7 @@ pub(crate) fn read_from_npz( read_from_npy(&mut f, shape) } -pub(crate) fn write_to_npy( +fn write_to_npy( w: &mut W, shape: &[usize], data: &[E], @@ -98,7 +98,7 @@ pub(crate) fn write_to_npy( Ok(()) } -pub(crate) fn read_from_npy( +fn read_from_npy( r: &mut R, shape: &[usize], ) -> Result, NpyError> { From 0dbb1e07abe997abefea11e9b6d71d92dcadf158 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Thu, 29 Jun 2023 00:07:12 -0500 Subject: [PATCH 10/10] small simplification --- src/nn/safetensors.rs | 6 +++++- src/tensor/safetensors.rs | 18 ++---------------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/nn/safetensors.rs b/src/nn/safetensors.rs index c552369cd..46f04a80d 100644 --- a/src/nn/safetensors.rs +++ b/src/nn/safetensors.rs @@ -87,7 +87,11 @@ impl> TensorVisitor for Writer { let data = TensorData { dtype: safetensors::Dtype::F64, shape: Vec::new(), - data: n.to_f64().unwrap_or_else(|| panic!("Failed to convert scalar value at {full_path} to f64!")).to_le_bytes().to_vec(), + data: n + .to_f64() + .unwrap_or_else(|| panic!("Failed to convert scalar value at {full_path} to f64!")) + .to_le_bytes() + .to_vec(), }; self.tensors.push((full_path, data)); Ok(None) diff --git a/src/tensor/safetensors.rs b/src/tensor/safetensors.rs index 13a119783..bd717cad9 100644 --- a/src/tensor/safetensors.rs +++ b/src/tensor/safetensors.rs @@ -14,12 +14,7 @@ pub trait SafeDtype: Sized { impl SafeDtype for f32 { type Array = [u8; 4]; fn from_le_bytes(bytes: &[u8], index: usize) -> Self { - Self::from_le_bytes([ - bytes[index], - bytes[index + 1], - bytes[index + 2], - bytes[index + 3], - ]) + Self::from_le_bytes(bytes[index..index + 4].try_into().unwrap()) } fn to_le_bytes(self) -> Self::Array { @@ -34,16 +29,7 @@ impl SafeDtype for f32 { impl SafeDtype for f64 { type Array = [u8; 8]; fn from_le_bytes(bytes: &[u8], index: usize) -> Self { - Self::from_le_bytes([ - bytes[index], - bytes[index + 1], - bytes[index + 2], - bytes[index + 3], - bytes[index + 4], - bytes[index + 5], - bytes[index + 6], - bytes[index + 7], - ]) + Self::from_le_bytes(bytes[index..index + 8].try_into().unwrap()) } fn safe_dtype() -> SDtype {