Add scalar support to TensorCollection #799

Merged: 12 commits, Jul 10, 2023
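This PR adds a `Self::scalar` constructor to `TensorCollection` and a matching `ModuleVisitor::visit_scalar` method, so non-tensor hyperparameters are visited alongside tensors. `BatchNorm1D`/`BatchNorm2D` (`epsilon`, `momentum`), `LayerNorm1D` (`epsilon`), and `Dropout` (`p`) now declare these fields with `ScalarOptions::from_default(...)`, and the npz and safetensors visitors serialize each scalar as a single `f64`, so custom values round-trip through save/load instead of being reset to hard-coded defaults.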
src/nn/batchnorm1d.rs (18 changes: 15 additions & 3 deletions)
@@ -199,14 +199,26 @@ impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for BatchNor
                     |s| &mut s.running_var,
                     TensorOptions::detached(|t| t.try_fill_with_ones()),
                 ),
+                Self::scalar(
+                    "epsilon",
+                    |s| &s.epsilon,
+                    |s| &mut s.epsilon,
+                    ScalarOptions::from_default(1e-5),
+                ),
+                Self::scalar(
+                    "momentum",
+                    |s| &s.momentum,
+                    |s| &mut s.momentum,
+                    ScalarOptions::from_default(0.1),
+                ),
             ),
-            |(scale, bias, running_mean, running_var)| BatchNorm1D {
+            |(scale, bias, running_mean, running_var, epsilon, momentum)| BatchNorm1D {
                 scale,
                 bias,
                 running_mean,
                 running_var,
-                epsilon: 1e-5,
-                momentum: 0.1,
+                epsilon,
+                momentum,
             },
         )
     }
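With `epsilon` and `momentum` now visited as scalars, custom values survive serialization rather than being rebuilt as `1e-5` / `0.1`. A rough sketch of the effect, assuming dfdx's usual `DeviceBuildExt::build_module` and the `SaveToNpz`/`LoadFromNpz` traits (exact paths and signatures may differ):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // Build with defaults, then tweak the scalar hyperparameters.
    let mut bn = dev.build_module::<BatchNorm1D<8>, f32>();
    bn.epsilon = 1e-3;
    bn.momentum = 0.05;
    bn.save("bn.npz").unwrap();

    // Previously these fields were re-created as 1e-5 / 0.1 on load;
    // with visit_scalar they are read back from the archive.
    let mut restored = dev.build_module::<BatchNorm1D<8>, f32>();
    restored.load("bn.npz").unwrap();
    assert_eq!(restored.epsilon, 1e-3);
    assert_eq!(restored.momentum, 0.05);
}
```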
src/nn/batchnorm2d.rs (18 changes: 15 additions & 3 deletions)
@@ -272,14 +272,26 @@ impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for BatchNor
                     |s| &mut s.running_var,
                     TensorOptions::detached(|t| t.try_fill_with_ones()),
                 ),
+                Self::scalar(
+                    "epsilon",
+                    |s| &s.epsilon,
+                    |s| &mut s.epsilon,
+                    ScalarOptions::from_default(1e-5),
+                ),
+                Self::scalar(
+                    "momentum",
+                    |s| &s.momentum,
+                    |s| &mut s.momentum,
+                    ScalarOptions::from_default(0.1),
+                ),
             ),
-            |(scale, bias, running_mean, running_var)| BatchNorm2D {
+            |(scale, bias, running_mean, running_var, epsilon, momentum)| BatchNorm2D {
                 scale,
                 bias,
                 running_mean,
                 running_var,
-                epsilon: 1e-5,
-                momentum: 0.1,
+                epsilon,
+                momentum,
             },
         )
     }
src/nn/dropout.rs (22 changes: 21 additions & 1 deletion)
@@ -129,7 +129,27 @@ impl Default for Dropout {
     }
 }
 
-impl ZeroSizedModule for Dropout {}
+impl<D: Device<E>, E: Dtype> BuildOnDevice<D, E> for Dropout {
+    type Built = Dropout;
+}
+
+impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for Dropout {
+    type To<E2: Dtype, D2: Device<E2>> = Dropout;
+
+    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
+        visitor: &mut V,
+    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
+        visitor.visit_fields(
+            <Self as TensorCollection<E, D>>::scalar(
+                "p",
+                |s| &s.p,
+                |s| &mut s.p,
+                ScalarOptions::from_default(0.5),
+            ),
+            |p| Dropout { p },
+        )
+    }
+}
 
 impl<S: Shape, E: Dtype, D: Device<E>> Module<Tensor<S, E, D, NoneTape>> for Dropout {
     type Output = Tensor<S, E, D, NoneTape>;
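`Dropout` previously relied on `ZeroSizedModule`, which appears to treat the module as having no fields; now that `p` is a visited scalar, it gets its own `BuildOnDevice` and `TensorCollection` impls, shown above. Any module whose only state is scalars can follow the same shape. A hypothetical sketch (the `Temperature` type and its field are invented for illustration, and the import path is assumed):

```rust
use dfdx::prelude::*; // assumed location of TensorCollection, ModuleVisitor, ScalarOptions

#[derive(Clone, Debug)]
pub struct Temperature {
    pub t: f64,
}

impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for Temperature {
    // No tensors, so the converted type is the same for every dtype/device.
    type To<E2: Dtype, D2: Device<E2>> = Temperature;

    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        visitor.visit_fields(
            // Fully qualified, as in the Dropout impl, since this impl is
            // generic over every (E, D) pair.
            <Self as TensorCollection<E, D>>::scalar(
                "t",
                |s| &s.t,
                |s| &mut s.t,
                ScalarOptions::from_default(1.0),
            ),
            |t| Temperature { t },
        )
    }
}
```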
src/nn/layer_norm.rs (10 changes: 8 additions & 2 deletions)
@@ -64,11 +64,17 @@ impl<const M: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for LayerNor
                     |s| &mut s.beta,
                     TensorOptions::reset_to_zeros(),
                 ),
+                Self::scalar(
+                    "epsilon",
+                    |s| &s.epsilon,
+                    |s| &mut s.epsilon,
+                    ScalarOptions::from_default(1e-5),
+                ),
             ),
-            |(gamma, beta)| LayerNorm1D {
+            |(gamma, beta, epsilon)| LayerNorm1D {
                 gamma,
                 beta,
-                epsilon: 1e-5,
+                epsilon,
             },
         )
     }
src/nn/npz.rs (30 changes: 29 additions & 1 deletion)
@@ -1,7 +1,7 @@
 use crate::{
     shapes::{Dtype, Shape},
     tensor::{
-        numpy::{NpzError, NumpyDtype},
+        numpy::{read_from_npz, write_to_npz, NpzError, NumpyDtype},
         Tensor,
     },
     tensor_ops::Device,
@@ -129,6 +129,18 @@ impl<W: Write + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D>
         t.write_to_npz(self, full_path)?;
         Ok(None)
     }
+
+    fn visit_scalar<N: num_traits::NumCast>(
+        &mut self,
+        _opts: ScalarOptions<N>,
+        (n, full_path): (&N, String),
+    ) -> Result<Option<N>, Self::Err> {
+        let n = n
+            .to_f64()
+            .unwrap_or_else(|| panic!("Failed to convert scalar value at {full_path} to f64!"));
+        write_to_npz(self, &[], &[n], full_path)?;
+        Ok(None)
+    }
 }
 
 impl<R: Read + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D>
@@ -147,6 +159,22 @@ impl<R: Read + Seek, E: Dtype + NumpyDtype, D: Device<E>> TensorVisitor<E, D>
         t.read_from_npz(self, full_path)?;
         Ok(None)
     }
+
+    fn visit_scalar<N: num_traits::NumCast>(
+        &mut self,
+        _opts: ScalarOptions<N>,
+        (n, full_path): (&mut N, String),
+    ) -> Result<Option<N>, Self::Err> {
+        let buf: Vec<f64> = read_from_npz(self, &[], full_path)?;
+        *n = N::from(buf[0]).unwrap_or_else(|| {
+            panic!(
+                "Failed to convert f64 value {} to {} when reading from npz!",
+                buf[0],
+                std::any::type_name::<N>()
+            )
+        });
+        Ok(None)
+    }
 }
 
 #[cfg(test)]
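On the npz side, each scalar is written as a zero-dimensional array (empty `&[]` shape) holding one `f64`, keyed by the field's path in the archive, and converted back through `num_traits::NumCast` on load. A small standalone sketch of the conversion contract these visitors rely on (only `num_traits`, no dfdx types):

```rust
use num_traits::NumCast;

// visit_scalar funnels every scalar through f64. NumCast keeps the conversion
// explicit and fallible, which is why both visitors panic with a descriptive
// message instead of silently truncating.
fn store<N: NumCast>(n: &N) -> f64 {
    n.to_f64().expect("scalar not representable as f64")
}

fn restore<N: NumCast>(x: f64) -> N {
    N::from(x).expect("f64 value not representable in the target scalar type")
}

fn main() {
    let written = store(&0.1f32);     // f32 -> f64, exact
    let read: f32 = restore(written); // f64 -> f32, recovers the original value
    assert_eq!(read, 0.1f32);
}
```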
src/nn/safetensors.rs (52 changes: 45 additions & 7 deletions)
@@ -12,7 +12,6 @@ use safetensors::{
     tensor::{Dtype as SDtype, SafeTensors, TensorView},
     SafeTensorError,
 };
-use std::collections::BTreeMap;
 
 use super::tensor_collection::*;
 
@@ -25,12 +24,12 @@ struct TensorData {
 }
 
 pub struct Writer {
-    tensors: BTreeMap<String, TensorData>,
+    tensors: Vec<(String, TensorData)>,
 }
 
 impl Writer {
     pub fn new() -> Self {
-        let tensors = BTreeMap::new();
+        let tensors = Vec::new();
         Self { tensors }
     }
 
@@ -44,11 +43,11 @@ impl Writer {
         let data = tensor.as_vec();
         let data: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
         let tdata = TensorData { dtype, shape, data };
-        self.tensors.insert(key, tdata);
+        self.tensors.push((key, tdata));
     }
 
     pub fn save(&self, path: &Path) -> Result<(), SafeTensorError> {
-        let views: BTreeMap<String, TensorView> = self
+        let (names, views): (Vec<String>, Vec<TensorView>) = self
             .tensors
             .iter()
             .map(|(k, tensor)| {
@@ -57,8 +56,11 @@ impl Writer {
                     TensorView::new(tensor.dtype, tensor.shape.clone(), &tensor.data).unwrap(),
                 )
             })
-            .collect();
-        serialize_to_file(&views, &None, path)
+            .unzip();
+
+        let data = names.into_iter().zip(views.iter());
+
+        serialize_to_file(data, &None, path)
     }
 }
 
@@ -76,6 +78,24 @@ impl<E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for Writer {
         self.add(full_path, t);
         Ok(None)
     }
+
+    fn visit_scalar<N: num_traits::NumCast>(
+        &mut self,
+        _: ScalarOptions<N>,
+        (n, full_path): (&N, String),
+    ) -> Result<Option<N>, Self::Err> {
+        let data = TensorData {
+            dtype: safetensors::Dtype::F64,
+            shape: Vec::new(),
+            data: n
+                .to_f64()
+                .unwrap_or_else(|| panic!("Failed to convert scalar value at {full_path} to f64!"))
+                .to_le_bytes()
+                .to_vec(),
+        };
+        self.tensors.push((full_path, data));
+        Ok(None)
+    }
 }
 
 /// Something that can be saved to a `.safetensors`.
@@ -148,6 +168,24 @@ impl<'data, E: Dtype + SafeDtype, D: Device<E>> TensorVisitor<E, D> for SafeTens
         t.load_safetensor(self, &full_path)?;
         Ok(None)
     }
+
+    fn visit_scalar<N: num_traits::NumCast>(
+        &mut self,
+        _: ScalarOptions<N>,
+        (n, full_path): (&mut N, String),
+    ) -> Result<Option<N>, Self::Err> {
+        let data = self.tensor(&full_path)?.data();
+        let mut array = [0; 8];
+        array.copy_from_slice(data);
+        let val = f64::from_le_bytes(array);
+        *n = N::from(val).unwrap_or_else(|| {
+            panic!(
+                "Failed to convert f64 value {val} at {full_path} to {} when reading from safetensors!",
+                std::any::type_name::<N>()
+            )
+        });
+        Ok(None)
+    }
 }
 
 #[cfg(test)]
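In the safetensors format a scalar becomes a rank-0 `F64` entry: an empty shape and exactly eight little-endian bytes, which is what the read path's fixed `[0; 8]` buffer assumes. The `Writer` also switches from a `BTreeMap` to a `Vec`, so entries keep insertion order and are handed to `serialize_to_file` as an iterator of `(name, view)` pairs. A minimal illustration of the byte round trip (plain std, no dfdx types):

```rust
fn main() {
    // Write path: the scalar is converted to f64 and stored as its 8 LE bytes.
    let epsilon: f64 = 1e-5;
    let stored: Vec<u8> = epsilon.to_le_bytes().to_vec();

    // Read path: copy the 8 bytes back into a fixed array and decode.
    let mut buf = [0u8; 8];
    buf.copy_from_slice(&stored);
    assert_eq!(f64::from_le_bytes(buf), 1e-5);
}
```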
src/nn/tensor_collection/collection.rs (52 changes: 51 additions & 1 deletion)
@@ -1,12 +1,13 @@
 #![allow(clippy::type_complexity)]
+use num_traits::NumCast;
 
 use crate::{
     shapes::{ConstShape, Dtype, Shape},
     tensor::{OneFillStorage, Tensor, ZeroFillStorage},
     tensor_ops::Device,
 };
 
-use super::{ModuleField, ModuleFields, TensorField};
+use super::{ModuleField, ModuleFields, ScalarField, TensorField};
 
 /// A collection of named tensors. Implementing this trait will enable anything
 /// that operates on tensors, including resetting, counting number of params, updating gradients,
@@ -105,6 +106,29 @@ pub trait TensorCollection<E: Dtype, D: Device<E>>: Sized {
             m: Default::default(),
         }
     }
+
+    /// Creates a [ModuleFields] that represents a scalar field.
+    ///
+    /// See also: [TensorField], [TensorCollection], [TensorOptions].
+    fn scalar<F1, F2, N>(
+        name: &str,
+        get_ref: F1,
+        get_mut: F2,
+        options: ScalarOptions<N>,
+    ) -> ScalarField<F1, F2, Self, N>
+    where
+        F1: FnMut(&Self) -> &N,
+        F2: FnMut(&mut Self) -> &mut N,
+        N: NumCast,
+    {
+        ScalarField {
+            name,
+            get_ref,
+            get_mut,
+            options,
+            m: Default::default(),
+        }
+    }
 }
 
 /// An object that can visit [TensorCollection]s and [Tensor]s recursively.
@@ -137,6 +161,18 @@ pub trait ModuleVisitor<T: TensorCollection<E, D>, E: Dtype, D: Device<E>>: Size
         GetRef: FnMut(&T) -> &Tensor<S, E, D>,
         GetMut: FnMut(&mut T) -> &mut Tensor<S, E, D>;
 
+    fn visit_scalar<N, GetRef, GetMut>(
+        &mut self,
+        name: &str,
+        get_refs: GetRef,
+        get_muts: GetMut,
+        opts: ScalarOptions<N>,
+    ) -> Result<Option<N>, Self::Err>
+    where
+        N: NumCast,
+        GetRef: FnMut(&T) -> &N,
+        GetMut: FnMut(&mut T) -> &mut N;
+
     /// Takes something that implements [ModuleFields] and function that takes
     /// [ModuleFields::Output] and returns an instance of T.
     fn visit_fields<M: ModuleFields<T, E, D>>(
@@ -229,3 +265,17 @@ impl<S: Shape, E: Dtype, D: Device<E>> TensorOptions<S, E, D> {
         }
     }
 }
+
+/// Options to change behavior of [ModuleVisitor]
+#[non_exhaustive]
+pub struct ScalarOptions<N: NumCast> {
+    /// The default value for this parameter
+    pub default: N,
+}
+
+impl<N: NumCast> ScalarOptions<N> {
+    /// Constructs a ScalarOptions using the parameter's default value
+    pub fn from_default(default: N) -> Self {
+        Self { default }
+    }
+}
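`scalar` mirrors the existing `tensor` constructor, so a module can mix both kinds of fields in one `visit_fields` call, exactly as the batch norm impls above do. A hypothetical sketch (the `ScaledBias` module is invented for illustration, and `Self::tensor` / `Rank1` are assumed from the existing dfdx API):

```rust
use dfdx::prelude::*;

pub struct ScaledBias<const M: usize, E: Dtype, D: Device<E>> {
    pub bias: Tensor<Rank1<M>, E, D>,
    pub alpha: f64,
}

impl<const M: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for ScaledBias<M, E, D> {
    type To<E2: Dtype, D2: Device<E2>> = ScaledBias<M, E2, D2>;

    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
        visitor: &mut V,
    ) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
        visitor.visit_fields(
            (
                // Tensor parameter: reset to zeros when the module is (re)initialized.
                Self::tensor(
                    "bias",
                    |s| &s.bias,
                    |s| &mut s.bias,
                    TensorOptions::reset_to_zeros(),
                ),
                // Scalar hyperparameter: falls back to 1.0 if absent from a checkpoint.
                Self::scalar(
                    "alpha",
                    |s| &s.alpha,
                    |s| &mut s.alpha,
                    ScalarOptions::from_default(1.0),
                ),
            ),
            |(bias, alpha)| ScaledBias { bias, alpha },
        )
    }
}
```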
src/nn/tensor_collection/mod.rs (6 changes: 3 additions & 3 deletions)
@@ -6,8 +6,8 @@ mod collection;
 mod visitor;
 mod visitor_impls;
 
-pub use collection::{ModuleVisitor, TensorCollection, TensorOptions};
+pub use collection::{ModuleVisitor, ScalarOptions, TensorCollection, TensorOptions};
 pub use visitor::{
-    ModuleField, ModuleFields, RecursiveWalker, TensorField, TensorViewer, TensorVisitor,
-    ViewTensorMut, ViewTensorName, ViewTensorRef,
+    ModuleField, ModuleFields, RecursiveWalker, ScalarField, TensorField, TensorViewer,
+    TensorVisitor, ViewTensorMut, ViewTensorName, ViewTensorRef,
 };