[Breaking] Moving storage GAT to trait level generic. Split DeviceStorage into multiple traits (#782)

* [Breaking] Rename DeviceStorage -> Storage. Move GAT on Vec to trait Generic

* Fixing cuda kernels usage of Storage

* Removing SampleUniform bound from Dtype

* Fixing docs

* Fixing docs
coreylowman authored May 12, 2023
1 parent fe2de4f commit 37a711a
Showing 118 changed files with 605 additions and 571 deletions.
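The heart of the change is the first commit-message bullet above: DeviceStorage is renamed to Storage, and the generic associated type (GAT) that carried the element type on Vec becomes a trait-level generic, so bounds can name the dtype directly (D: Storage<E>). The sketch below only illustrates that pattern, using simplified stand-in traits and methods, not dfdx's actual definitions.

// Before: the element type lived on a GAT, so the dtype could not appear
// in the trait bound itself.
trait DeviceStorage {
    type Vec<E: Copy>: Clone;
    fn alloc<E: Copy + Default>(&self, len: usize) -> Self::Vec<E>;
}

// After: the dtype is a trait-level generic, so per-dtype bounds such as
// `D: Storage<f32>` can be written on structs and impls.
trait Storage<E: Copy> {
    type Vec: Clone;
    fn alloc(&self, len: usize) -> Self::Vec;
}

struct Cpu;

impl<E: Copy + Default> Storage<E> for Cpu {
    type Vec = Vec<E>;
    fn alloc(&self, len: usize) -> Self::Vec {
        vec![E::default(); len]
    }
}

fn main() {
    let dev = Cpu;
    let buffer: Vec<f32> = Storage::<f32>::alloc(&dev, 4);
    assert_eq!(buffer.len(), 4);
}

This is why, throughout the diff below, struct and impl bounds change from `D: DeviceStorage` to `D: Storage<E>`.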
2 changes: 1 addition & 1 deletion examples/07-custom-module.rs
@@ -22,7 +22,7 @@ struct Mlp<const IN: usize, const INNER: usize, const OUT: usize, E: Dtype, D: D
impl<const IN: usize, const INNER: usize, const OUT: usize, E, D: Device<E>> TensorCollection<E, D>
for Mlp<IN, INNER, OUT, E, D>
where
E: Dtype + num_traits::Float,
E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform,
{
// Type alias that specifies the how Mlp's type changes when using a different dtype and/or
// device.
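The extra `rand_distr::uniform::SampleUniform` bound appears here because the commit removes that bound from Dtype itself (third commit-message bullet), so any impl that samples parameters uniformly must now state it explicitly. Below is a minimal sketch of the kind of code that needs the bound, assuming the rand 0.8 / rand_distr 0.4 API; the function name and numbers are illustrative only.

use rand_distr::uniform::SampleUniform;
use rand_distr::{Distribution, Uniform};

// Sampling from [-bound, bound] requires E: SampleUniform, which is the
// bound the diff above adds alongside Dtype + num_traits::Float.
fn uniform_init<E>(bound: E, n: usize) -> Vec<E>
where
    E: SampleUniform + Copy + std::ops::Neg<Output = E>,
{
    let dist = Uniform::new_inclusive(-bound, bound);
    let mut rng = rand::thread_rng();
    (0..n).map(|_| dist.sample(&mut rng)).collect()
}

fn main() {
    let weights = uniform_init(0.5f32, 8);
    assert_eq!(weights.len(), 8);
}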
4 changes: 2 additions & 2 deletions src/data/arange.rs
@@ -1,12 +1,12 @@
use crate::{
shapes::*,
tensor::{DeviceStorage, Tensor, TensorFromVec, ZerosTensor},
tensor::{Storage, Tensor, TensorFromVec, ZerosTensor},
};

use std::vec::Vec;

/// Generates a tensor with ordered data from 0 to `N`.
pub trait Arange<E: Dtype>: DeviceStorage + ZerosTensor<E> + TensorFromVec<E> {
pub trait Arange<E: Dtype>: Storage<E> + ZerosTensor<E> + TensorFromVec<E> {
/// Generates a tensor with ordered data from 0 to `N`.
///
/// Const sized tensor:
6 changes: 3 additions & 3 deletions src/data/one_hot_encode.rs
@@ -2,12 +2,12 @@ use std::vec::Vec;

use crate::{
shapes::*,
tensor::{DeviceStorage, Tensor, TensorFromVec, ZerosTensor},
tensor::{Storage, Tensor, TensorFromVec, ZerosTensor},
};

/// One hot encodes an array of class labels into a 2d tensor of probability
/// vectors. This can be used in tandem with [crate::losses::cross_entropy_with_logits_loss()].
pub trait OneHotEncode<E: Dtype>: DeviceStorage + ZerosTensor<E> + TensorFromVec<E> {
pub trait OneHotEncode<E: Dtype>: Storage<E> + ZerosTensor<E> + TensorFromVec<E> {
/// One hot encodes an array or vec into a tensor.
///
/// Arguments:
@@ -92,4 +92,4 @@ pub trait OneHotEncode<E: Dtype>: DeviceStorage + ZerosTensor<E> + TensorFromVec
self.tensor_from_vec(data, (l, n))
}
}
impl<E: Dtype, D: DeviceStorage + ZerosTensor<E> + TensorFromVec<E>> OneHotEncode<E> for D {}
impl<E: Dtype, D: Storage<E> + ZerosTensor<E> + TensorFromVec<E>> OneHotEncode<E> for D {}
2 changes: 1 addition & 1 deletion src/nn/batchnorm1d.rs
@@ -55,7 +55,7 @@ where
/// - Running statistics: **not** updated
/// - Normalization: calculated using running stats
#[derive(Clone, Debug)]
pub struct BatchNorm1D<const C: usize, E: Dtype, D: DeviceStorage> {
pub struct BatchNorm1D<const C: usize, E: Dtype, D: Storage<E>> {
/// Scale for affine transform. Defaults to 1.0
pub scale: Tensor<Rank1<C>, E, D>,
/// Bias for affine transform. Defaults to 0.0
2 changes: 1 addition & 1 deletion src/nn/batchnorm2d.rs
@@ -129,7 +129,7 @@ where
/// - Running statistics: **not** updated
/// - Normalization: calculated using running stats
#[derive(Clone, Debug)]
pub struct BatchNorm2D<const C: usize, E: Dtype, D: DeviceStorage> {
pub struct BatchNorm2D<const C: usize, E: Dtype, D: Storage<E>> {
/// Scale for affine transform. Defaults to 1.0
pub scale: Tensor<Rank1<C>, E, D>,
/// Bias for affine transform. Defaults to 0.0
4 changes: 2 additions & 2 deletions src/nn/bias2d.rs
@@ -37,11 +37,11 @@ where
/// model.forward(x);
/// ```
#[derive(Clone, Debug)]
pub struct Bias2D<const C: usize, E: Dtype, D: DeviceStorage> {
pub struct Bias2D<const C: usize, E: Dtype, D: Storage<E>> {
pub bias: Tensor<Rank1<C>, E, D>,
}

impl<const C: usize, E: Dtype, D: DeviceStorage> NonMutableModule for Bias2D<C, E, D> {}
impl<const C: usize, E: Dtype, D: Storage<E>> NonMutableModule for Bias2D<C, E, D> {}

impl<const C: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for Bias2D<C, E, D> {
type To<E2: Dtype, D2: Device<E2>> = Bias2D<C, E2, D2>;
16 changes: 11 additions & 5 deletions src/nn/build_module.rs
@@ -1,12 +1,15 @@
use super::tensor_collection::*;
use crate::{
shapes::{Dtype, Shape},
tensor::{DeviceStorage, Tensor},
tensor::{Storage, Tensor},
tensor_ops::Device,
};

struct Builder<'a, D: DeviceStorage>(&'a D);
impl<'a, E: Dtype, D: Device<E>> TensorVisitor<E, D> for Builder<'a, D> {
struct Builder<'a, E, D: Storage<E>> {
device: &'a D,
dtype: std::marker::PhantomData<E>,
}
impl<'a, E: Dtype, D: Device<E>> TensorVisitor<E, D> for Builder<'a, E, D> {
type Viewer = ();
type Err = D::Err;
type E2 = E;
@@ -17,7 +20,7 @@ impl<'a, E: Dtype, D: Device<E>> TensorVisitor<E, D> for Builder<'a, D> {
opts: TensorOptions<S, E, D>,
_t: (),
) -> Result<Option<Tensor<S, E, D>>, Self::Err> {
let mut tensor: Tensor<S, E, D> = self.0.try_zeros_like(&opts.shape)?;
let mut tensor: Tensor<S, E, D> = self.device.try_zeros_like(&opts.shape)?;
(opts.reset)(&mut tensor)?;
Ok(Some(tensor))
}
@@ -36,7 +39,10 @@ pub trait BuildModule<D: Device<E>, E: Dtype>:
fn try_build(device: &D) -> Result<Self, D::Err> {
let out = Self::iter_tensors(&mut RecursiveWalker {
m: (),
f: &mut Builder(device),
f: &mut Builder {
device,
dtype: std::marker::PhantomData,
},
})?;

Ok(out.unwrap())
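Because the builder's device is now bounded per-dtype (D: Storage<E>), the struct must declare E even though no field stores an E value; the zero-sized PhantomData<E> marker satisfies Rust's unused-type-parameter rule. A self-contained sketch of the same pattern with placeholder types, unrelated to dfdx's real Builder:

use std::marker::PhantomData;

// E appears only in bounds (simplified away here), not in any field, so a
// PhantomData<E> marker field is added to keep the compiler happy.
struct Builder<'a, E, D> {
    device: &'a D,
    dtype: PhantomData<E>,
}

fn main() {
    let dev = String::from("cpu"); // stand-in for a real device value
    let builder: Builder<'_, f32, String> = Builder {
        device: &dev,
        dtype: PhantomData,
    };
    println!("building on {}", builder.device);
}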
9 changes: 3 additions & 6 deletions src/nn/conv.rs
@@ -72,7 +72,7 @@ pub struct Conv2D<
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}
@@ -156,12 +156,9 @@ impl<
const P: usize,
const L: usize,
const G: usize,
E,
D,
E: Dtype,
D: Storage<E>,
> NonMutableModule for Conv2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: DeviceStorage,
{
}

4 changes: 2 additions & 2 deletions src/nn/convtrans.rs
@@ -53,7 +53,7 @@ pub struct ConvTrans2D<
const STRIDE: usize,
const PADDING: usize,
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}
@@ -104,7 +104,7 @@ impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: us
NonMutableModule for ConvTrans2D<I, O, K, S, P, E, D>
where
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
{
}

4 changes: 2 additions & 2 deletions src/nn/embedding.rs
@@ -50,12 +50,12 @@ where
/// let _: Tensor<(Const<10>, Const<5>, Const<2>), f32, _> = model.forward(inputs);
/// ```
#[derive(Debug, Clone)]
pub struct Embedding<const VOCAB: usize, const DIM: usize, E: Dtype, D: DeviceStorage> {
pub struct Embedding<const VOCAB: usize, const DIM: usize, E: Dtype, D: Storage<E>> {
/// Transposed weight matrix, shape (I, O)
pub weight: Tensor<Rank2<VOCAB, DIM>, E, D>,
}

impl<const V: usize, const M: usize, E: Dtype, D: DeviceStorage> NonMutableModule
impl<const V: usize, const M: usize, E: Dtype, D: Storage<E>> NonMutableModule
for Embedding<V, M, E, D>
{
}
32 changes: 31 additions & 1 deletion src/nn/impl_module_for_tuples.rs
@@ -1,7 +1,37 @@
use crate::{shapes::*, tensor_ops::*};
use crate::{shapes::*, tensor::HasErr, tensor_ops::*};

use super::*;

impl<E: Dtype, D: Device<E>> TensorCollection<E, D> for () {
type To<E2: Dtype, D2: Device<E2>> = ();

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
_: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
Ok(None)
}
}

impl<D: Device<E>, E: Dtype> BuildOnDevice<D, E> for () {
type Built = ();
}

impl<X: HasErr> Module<X> for () {
type Output = X;
type Error = X::Err;
fn try_forward(&self, input: X) -> Result<Self::Output, Self::Error> {
Ok(input)
}
}

impl<X: HasErr> ModuleMut<X> for () {
type Output = X;
type Error = X::Err;
fn try_forward_mut(&mut self, input: X) -> Result<Self::Output, Self::Error> {
Ok(input)
}
}

macro_rules! tuple_impls {
([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),*]) => {
impl<E: Dtype, D: Device<E>, $($name: TensorCollection<E, D>),+> TensorCollection<E, D> for ($($name,)+) {
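The new impls above make the unit type () a no-op module: it builds to itself, owns no tensors, and forwards its input unchanged, which is useful as a placeholder inside tuples of layers. A toy illustration of the same idea, using a simplified trait rather than dfdx's Module:

// Simplified stand-in for a forward trait; `()` acts as an identity layer.
trait Forward<X> {
    type Output;
    fn forward(&self, input: X) -> Self::Output;
}

impl<X> Forward<X> for () {
    type Output = X;
    fn forward(&self, input: X) -> X {
        input
    }
}

// A layer that doubles its input, to show `()` composing next to real work.
struct Double;

impl Forward<f32> for Double {
    type Output = f32;
    fn forward(&self, input: f32) -> f32 {
        input * 2.0
    }
}

fn main() {
    let identity = ();
    let out = Double.forward(identity.forward(3.0f32));
    assert_eq!(out, 6.0);
}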
4 changes: 2 additions & 2 deletions src/nn/layer_norm.rs
@@ -36,13 +36,13 @@ where
/// ```

#[derive(Debug, Clone)]
pub struct LayerNorm1D<const M: usize, E: Dtype, D: DeviceStorage> {
pub struct LayerNorm1D<const M: usize, E: Dtype, D: Storage<E>> {
pub gamma: Tensor<Rank1<M>, E, D>,
pub beta: Tensor<Rank1<M>, E, D>,
pub epsilon: f64,
}

impl<const M: usize, E: Dtype, D: DeviceStorage> NonMutableModule for LayerNorm1D<M, E, D> {}
impl<const M: usize, E: Dtype, D: Storage<E>> NonMutableModule for LayerNorm1D<M, E, D> {}

impl<const M: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for LayerNorm1D<M, E, D> {
type To<E2: Dtype, D2: Device<E2>> = LayerNorm1D<M, E2, D2>;
14 changes: 9 additions & 5 deletions src/nn/linear.rs
@@ -43,21 +43,25 @@ where
/// let _: Tensor<Rank2<10, 2>, f32, _> = model.forward(dev.zeros::<Rank2<10, 5>>());
/// ```
#[derive(Debug, Clone)]
pub struct Linear<const I: usize, const O: usize, E: Dtype, D: DeviceStorage> {
pub struct Linear<const I: usize, const O: usize, E: Dtype, D: Storage<E>> {
/// Transposed weight matrix, shape (I, O)
pub weight: Tensor<Rank2<O, I>, E, D>,

/// Bias vector, shape (O, )
pub bias: Tensor<Rank1<O>, E, D>,
}

impl<const I: usize, const O: usize, E: Dtype, D: DeviceStorage> NonMutableModule
impl<const I: usize, const O: usize, E: Dtype, D: Storage<E>> NonMutableModule
for Linear<I, O, E, D>
{
}

impl<const I: usize, const O: usize, E: Dtype + num_traits::Float, D: Device<E>>
TensorCollection<E, D> for Linear<I, O, E, D>
impl<
const I: usize,
const O: usize,
E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform,
D: Device<E>,
> TensorCollection<E, D> for Linear<I, O, E, D>
{
type To<E2: Dtype, D2: Device<E2>> = Linear<I, O, E2, D2>;

@@ -107,7 +111,7 @@ where
}

#[derive(Clone, Debug)]
struct Bias1D<'a, const M: usize, E: Dtype, D: DeviceStorage> {
struct Bias1D<'a, const M: usize, E: Dtype, D: Storage<E>> {
beta: &'a Tensor<Rank1<M>, E, D>,
}

8 changes: 5 additions & 3 deletions src/nn/module.rs
@@ -1,7 +1,7 @@
pub use super::build_module::BuildModule;
pub use super::to_device::*;

use crate::{shapes::Dtype, tensor::DeviceStorage, tensor_ops::Device};
use crate::{shapes::Dtype, tensor_ops::Device};

use super::tensor_collection::*;

@@ -55,7 +55,7 @@ pub trait BuildOnDevice<D: Device<E>, E: Dtype> {

/// An extension trait that allows you to build a module with a device
/// method. Also allows easy specification of Dtype.
pub trait DeviceBuildExt: DeviceStorage {
pub trait DeviceBuildExt {
fn build_module<M: BuildOnDevice<Self, E>, E: Dtype>(&self) -> M::Built
where
Self: Device<E>,
@@ -69,7 +69,9 @@ pub trait DeviceBuildExt: DeviceStorage {
M::try_build_on_device(self)
}
}
impl<D: DeviceStorage> DeviceBuildExt for D {}
impl DeviceBuildExt for crate::tensor::Cpu {}
#[cfg(feature = "cuda")]
impl DeviceBuildExt for crate::tensor::Cuda {}

/// Marker trait for modules with no updatable parameters. These have
/// blanket impls for, and [ModuleMut]
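With Storage<E> generic over the dtype, there is no longer a single non-generic storage trait to blanket over, so the blanket impl of DeviceBuildExt is replaced by one impl per concrete device, the Cuda impl gated on the cuda feature. A rough sketch of that cfg-gated pattern with placeholder types, not dfdx's actual devices or trait:

// Placeholder devices and trait; only the per-device, feature-gated impl
// pattern mirrors the diff above.
struct Cpu;

#[cfg(feature = "cuda")]
struct Cuda;

trait DeviceBuildExt {
    fn device_name(&self) -> &'static str;
}

impl DeviceBuildExt for Cpu {
    fn device_name(&self) -> &'static str {
        "cpu"
    }
}

#[cfg(feature = "cuda")]
impl DeviceBuildExt for Cuda {
    fn device_name(&self) -> &'static str {
        "cuda"
    }
}

fn main() {
    assert_eq!(Cpu.device_name(), "cpu");
}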
2 changes: 1 addition & 1 deletion src/nn/tensor_collection/collection.rs
@@ -23,7 +23,7 @@ use super::{ModuleField, ModuleFields, TensorField};
/// relu: ReLU,
/// }
///
/// impl<E: Dtype + num_traits::Float, D: Device<E>> TensorCollection<E, D> for Mlp<E, D> {
/// impl<E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform, D: Device<E>> TensorCollection<E, D> for Mlp<E, D> {
/// type To<E2: Dtype, D2: Device<E2>> = Mlp<E2, D2>;
///
/// fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
6 changes: 3 additions & 3 deletions src/nn/transformer/decoder.rs
@@ -4,7 +4,7 @@ use rand_distr::uniform::SampleUniform;
use crate::{
nn::modules::*,
shapes::Dtype,
tensor::{DeviceStorage, HasErr, PutTape, SplitTape},
tensor::{HasErr, PutTape, SplitTape, Storage},
tensor_ops::{Device, TryAdd},
};

@@ -64,7 +64,7 @@ pub struct TransformerDecoder<
const FF_DIM: usize,
const NUM_LAYERS: usize,
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
>(pub Repeated<TransformerDecoderBlock<MODEL_DIM, NUM_HEADS, FF_DIM, E, D>, NUM_LAYERS>);

impl<const M: usize, const H: usize, const F: usize, const L: usize, E: Dtype, D: Device<E>>
@@ -127,7 +127,7 @@ pub struct TransformerDecoderBlock<
const NUM_HEADS: usize,
const FF_DIM: usize,
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
> {
pub self_attn: MultiHeadAttention<MODEL_DIM, NUM_HEADS, MODEL_DIM, MODEL_DIM, E, D>,
pub norm1: LayerNorm1D<MODEL_DIM, E, D>,
4 changes: 2 additions & 2 deletions src/nn/transformer/encoder.rs
@@ -4,7 +4,7 @@ use rand_distr::uniform::SampleUniform;
use crate::{
nn::modules::*,
shapes::Dtype,
tensor::{DeviceStorage, PutTape, SplitTape},
tensor::{PutTape, SplitTape, Storage},
tensor_ops::Device,
};

@@ -85,7 +85,7 @@ pub struct TransformerEncoderBlock<
const NUM_HEADS: usize,
const FF_DIM: usize,
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
> {
pub self_attn: MultiHeadAttention<MODEL_DIM, NUM_HEADS, MODEL_DIM, MODEL_DIM, E, D>,
pub norm1: LayerNorm1D<MODEL_DIM, E, D>,
2 changes: 1 addition & 1 deletion src/nn/transformer/mha.rs
@@ -55,7 +55,7 @@ pub struct MultiHeadAttention<
const K_DIM: usize,
const V_DIM: usize,
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
> {
pub w_q: Linear<EMBED_DIM, K_DIM, E, D>,
pub w_k: Linear<EMBED_DIM, K_DIM, E, D>,
2 changes: 1 addition & 1 deletion src/nn/transformer/mod.rs
@@ -72,7 +72,7 @@ pub struct Transformer<
const NUM_DECODER_LAYERS: usize,
const FF_DIM: usize,
E: Dtype,
D: DeviceStorage,
D: Storage<E>,
> {
pub encoder: TransformerEncoder<MODEL_DIM, NUM_HEADS, FF_DIM, NUM_ENCODER_LAYERS, E, D>,
pub decoder: TransformerDecoder<MODEL_DIM, NUM_HEADS, FF_DIM, NUM_DECODER_LAYERS, E, D>,
4 changes: 2 additions & 2 deletions src/nn/unbiased_linear.rs
@@ -43,12 +43,12 @@ where
/// let _: Tensor<Rank2<10, 2>, f32, _> = model.forward(dev.zeros::<Rank2<10, 5>>());
/// ```
#[derive(Debug, Clone)]
pub struct UnbiasedLinear<const I: usize, const O: usize, E: Dtype, D: DeviceStorage> {
pub struct UnbiasedLinear<const I: usize, const O: usize, E: Dtype, D: Storage<E>> {
/// Transposed weight matrix, shape (I, O)
pub weight: Tensor<Rank2<O, I>, E, D>,
}

impl<const I: usize, const O: usize, E: Dtype, D: DeviceStorage> NonMutableModule
impl<const I: usize, const O: usize, E: Dtype, D: Storage<E>> NonMutableModule
for UnbiasedLinear<I, O, E, D>
{
}
2 changes: 1 addition & 1 deletion src/nn/zero_grads.rs
@@ -50,7 +50,7 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
}
impl<E: Dtype, D: Device<E>, M: TensorCollection<E, D>> ZeroGrads<E, D> for M {}

struct ZeroGradOp<'a, E: Unit, D: DeviceStorage> {
struct ZeroGradOp<'a, E: Unit, D: Storage<E>> {
updated: Vec<UniqueId>,
gradients: &'a mut Gradients<E, D>,
}