Merge branch 'main' into serialize_optimizers
nkoppel committed Jul 8, 2023
2 parents 0dbb1e0 + 91cf54a commit 1d31c57
Showing 14 changed files with 1,260 additions and 629 deletions.
10 changes: 4 additions & 6 deletions Cargo.toml
@@ -29,15 +29,13 @@ no-std-compat = { version = "0.4.1", default-features = false, features = [ "all
spin = { version = "0.9.8", default-features = false, features = ["spin_mutex", "rwlock", "portable_atomic"], optional = true }
rand = { version = "0.8.5", default-features = false, features = ["std_rng"] }
rand_distr = { version = "0.4.3", default-features = false }
zip = { version = "0.6.2", default-features = false, optional = true }
cblas-sys = { version = "0.1.4", default-features = false, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
cudarc = { git = "https://github.com/coreylowman/cudarc", branch = "dfdx-half", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
zip = { version = "0.6.6", default-features = false, optional = true }
cudarc = { version = "0.9.11", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
num-traits = { version = "0.2.15", default-features = false }
safetensors = { version = "0.3", default-features = false, optional = true }
memmap2 = { version = "0.5", default-features = false, optional = true }
half = { git = "https://github.com/starkat99/half-rs.git", branch = "main", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.15.3", default-features = false, optional = true }
half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.15.4", default-features = false, optional = true }
rayon = { version = "1.7.0", optional = true }

[dev-dependencies]
99 changes: 80 additions & 19 deletions src/nn/convtrans.rs
@@ -13,17 +13,29 @@ pub mod builder {
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::ConvTrans2D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
ConvTrans2D<I, O, K, S, P, E, D>: BuildModule<D, E>,
Const<{ O / G }>: Sized,
ConvTrans2D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = ConvTrans2D<I, O, K, S, P, E, D>;
type Built = ConvTrans2D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
@@ -45,26 +57,43 @@ where
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example,
#[derive(Debug, Clone)]
pub struct ConvTrans2D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> {
pub weight: Tensor<Rank4<OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
> where
Const<{ OUT_CHAN / GROUPS }>: Sized,
{
pub weight: Tensor<Rank4<IN_CHAN, { OUT_CHAN / GROUPS }, KERNEL_SIZE, KERNEL_SIZE>, E, D>,
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype + Float + SampleUniform,
D: Device<E>,
Const<{ O / G }>: Sized,
{
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, E2, D2>;
type To<E2: Dtype, D2: Device<E2>> = ConvTrans2D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
@@ -85,26 +114,58 @@ where
}

#[cfg(feature = "nightly")]
impl<const C: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D, Img>
Module<Img> for ConvTrans2D<C, O, K, S, P, E, D>
impl<
const C: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for ConvTrans2D<C, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Device<E>,
Img: TryConvTrans2DTo<Tensor<Rank4<O, C, K, K>, E, D>, S, P> + HasErr<Err = D::Err>,
Const<{ O / G }>: Sized,
(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>):
TryConvTrans2D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = Img::Output;
type Error = D::Err;
type Output = <(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank4<C, { O / G }, K, K>, E, D>) as TryConvTrans2D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, D::Err> {
x.try_convtrans2d_to(self.weight.clone())
fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
(x, self.weight.clone()).try_convtrans2d(Const, Const, Const, Const)
}
}

impl<const I: usize, const O: usize, const K: usize, const S: usize, const P: usize, E, D>
NonMutableModule for ConvTrans2D<I, O, K, S, P, E, D>
impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> NonMutableModule for ConvTrans2D<I, O, K, S, P, L, G, E, D>
where
E: Dtype,
D: Storage<E>,
Const<{ O / G }>: Sized,
{
}

@@ -187,7 +248,7 @@ mod tests {

assert_ne!(
g.get(&m.weight).array(),
[[[[TestDtype::zero(); 3]; 3]; 2]; 4]
[[[[TestDtype::zero(); 3]; 3]; 4]; 2]
);

opt.update(&mut m, &g).expect("unused params");
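The convtrans.rs diff above adds DILATION and GROUPS const parameters to ConvTrans2D and routes the forward pass through the TryConvTrans2D trait. As a rough usage sketch (not part of this commit), the following assumes dfdx's AutoDevice / build_module / sample_normal APIs and the nightly feature; the parameter order (in channels, out channels, kernel, stride, padding, dilation, groups) follows the builder type alias in the diff, and the concrete sizes and module path are invented for illustration.

// Hedged sketch: a grouped, strided transposed convolution using the new
// DILATION and GROUPS parameters. Paths and shapes are assumptions.
use dfdx::prelude::*;

fn main() {
    let dev = AutoDevice::default();

    // 8 in-channels, 16 out-channels, 3x3 kernel, stride 2, padding 1,
    // dilation 1, groups 2 (both channel counts are divisible by GROUPS).
    type Deconv = dfdx::nn::builders::ConvTrans2D<8, 16, 3, 2, 1, 1, 2>;
    let m = dev.build_module::<Deconv, f32>();

    // A single 8-channel 14x14 image; the output shape is computed by the
    // TryConvTrans2D impl from the stride/padding/dilation const parameters.
    let x: Tensor<Rank3<8, 14, 14>, f32, _> = dev.sample_normal();
    let _y = m.forward(x);
}

Note that the weight tensor now has shape (IN_CHAN, OUT_CHAN / GROUPS, KERNEL_SIZE, KERNEL_SIZE) rather than (OUT_CHAN, IN_CHAN, KERNEL_SIZE, KERNEL_SIZE), which is why the test assertion above swaps the two outer dimensions of the expected array.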
1 change: 1 addition & 0 deletions src/nn/mod.rs
@@ -190,6 +190,7 @@ mod batchnorm2d;
mod bias2d;
#[cfg(feature = "nightly")]
mod conv;
#[cfg(feature = "nightly")]
mod convtrans;
mod dropout;
mod ema;
46 changes: 24 additions & 22 deletions src/tensor_ops/conv2d/conv2d.cu
@@ -22,29 +22,29 @@ __device__ void unfold_input_into_patches(
const size_t *strides, // 4d image strides
T *patches // 6d (Batch, Groups * Channels, KernelSize, KernelSize, HeightOut, WidthOut)
) {
const size_t n = op.batch * op.groups * op.chan_in * op.h_out * op.w_out;
const size_t n = op.batch * op.chan_in * op.h_out * op.w_out;
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t ow = idx % op.w_out;
idx /= op.w_out;
const size_t oh = idx % op.h_out;
idx /= op.h_out;
const size_t c = idx % (op.chan_in * op.groups);
idx /= (op.chan_in * op.groups);
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t b = idx % op.batch;

const T *image_i = image + b * strides[0] + c * strides[1];
T *patches_i = patches + oh * op.w_out + ow;
patches_i += c * (op.kernel * op.kernel * op.h_out * op.w_out);
patches_i += b * (op.groups * op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out);
patches_i += b * (op.chan_in * op.kernel * op.kernel * op.h_out * op.w_out);

T zero = 0.0;

for (int k1 = 0;k1 < op.kernel;k1++) {
const size_t y = oh * op.stride + op.dilation * k1 - op.padding;
for (int k2 = 0;k2 < op.kernel;k2++) {
const size_t x = ow * op.stride + op.dilation * k2 - op.padding;
*patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image[y * strides[2] + x * strides[3]];
*patches_i = (y >= op.h_in || x >= op.w_in) ? zero : image_i[y * strides[2] + x * strides[3]];
patches_i += op.h_out * op.w_out;
}
}
@@ -86,7 +86,7 @@ __device__ void unfold_output_into_patches(
const size_t ow = ow_s / op.stride;

const bool invalid = k1_invalid || (ow_ks < op.dilation * k2 || ow_s % op.stride != 0 || ow >= op.w_out);
*patches_i = invalid ? zero : image_out[oh * op.w_out + ow];
*patches_i = invalid ? zero : image_i[oh * op.w_out + ow];
patches_i += op.h_in * op.w_in;
}
}
@@ -96,64 +96,66 @@ __device__ void unfold_output_into_patches(
template<typename T>
__device__ void transpose_filters(
const Conv2DOp op,
const T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize)
const size_t *strides, // 4d filters strides
T *filters_tr // 5d (Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
T *filters_tr // 5d (Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize)
) {
const size_t n = op.chan_in * op.chan_out * op.kernel * op.kernel;
const size_t c_per_g = op.chan_in / op.groups;
const size_t o_per_g = op.chan_out / op.groups;
const size_t n = c_per_g * op.chan_out * op.kernel * op.kernel;

for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t cg = idx % c_per_g;
idx /= c_per_g;
const size_t o = idx % op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3];
T *filters_tr_i = filters_tr + k2;
filters_tr_i += k1 * op.kernel;
filters_tr_i += og * (op.kernel * op.kernel);
filters_tr_i += c * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
filters_tr_i += cg * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel);
*filters_tr_i = filters[i_no];
}
}

template<typename T>
__device__ void sum_transposed_filters(
const Conv2DOp op,
const T *filters_tr, // 6d (Batch, Groups, ChanIn, ChanOut/Groups, KernelSize, KernelSize)
T *filters, // 4d (ChanOut, ChanIn, KernelSize, KernelSize)
const T *filters_tr, // 6d (Batch, Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize)
T *filters, // 4d (ChanOut, ChanIn/Groups, KernelSize, KernelSize)
const size_t *strides // 4d filter strides
) {
const size_t n = op.chan_out * op.chan_in * op.kernel * op.kernel;
const size_t o_per_g = op.chan_out / op.groups;
const size_t c_per_g = op.chan_in / op.groups;
const size_t n = op.chan_out * c_per_g * op.kernel * op.kernel;

for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
unsigned int idx = i;
const size_t k2 = idx % op.kernel;
idx /= op.kernel;
const size_t k1 = idx % op.kernel;
idx /= op.kernel;
const size_t c = idx % op.chan_in;
idx /= op.chan_in;
const size_t cg = idx % c_per_g;
idx /= c_per_g;
const size_t o = idx % op.chan_out;
const size_t og = o % o_per_g;
const size_t g = o / o_per_g;

auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3];
auto i_no = o * strides[0] + cg * strides[1] + k1 * strides[2] + k2 * strides[3];

const T *filters_tr_i = filters_tr + k2;
filters_tr_i += k1 * op.kernel;
filters_tr_i += og * (op.kernel * op.kernel);
filters_tr_i += c * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (op.chan_in * o_per_g * op.kernel * op.kernel);
filters_tr_i += cg * (o_per_g * op.kernel * op.kernel);
filters_tr_i += g * (c_per_g * o_per_g * op.kernel * op.kernel);

T tmp = 0.0;
for (int b = 0; b < op.batch; b++) {
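The conv2d.cu hunks above fix the patch unfolding to read through the per-image pointer image_i and rework the indexing for grouped convolutions: the patch buffer is sized by op.chan_in directly rather than op.groups * op.chan_in, and the transposed filters use a (Groups, ChanIn/Groups, ChanOut/Groups, KernelSize, KernelSize) layout. To illustrate that index math, here is a hedged CPU sketch in Rust of the reordering that transpose_filters performs; the function name, contiguous row-major strides, and f32 element type are assumptions, not code from the commit.

// Hedged CPU sketch of the grouped-filter reordering in transpose_filters:
// source layout (ChanOut, ChanIn/Groups, K, K), assumed contiguous row-major;
// target layout (Groups, ChanIn/Groups, ChanOut/Groups, K, K).
fn transpose_filters_cpu(
    filters: &[f32],
    chan_out: usize,
    chan_in: usize,
    groups: usize,
    kernel: usize,
) -> Vec<f32> {
    let c_per_g = chan_in / groups;
    let o_per_g = chan_out / groups;
    let k2 = kernel * kernel;
    let mut tr = vec![0.0f32; groups * c_per_g * o_per_g * k2];
    for o in 0..chan_out {
        let (g, og) = (o / o_per_g, o % o_per_g);
        for cg in 0..c_per_g {
            for k in 0..k2 {
                // (o, cg, k1, k2) in the source buffer ...
                let src = (o * c_per_g + cg) * k2 + k;
                // ... maps to (g, cg, og, k1, k2) in the transposed buffer.
                let dst = ((g * c_per_g + cg) * o_per_g + og) * k2 + k;
                tr[dst] = filters[src];
            }
        }
    }
    tr
}

sum_transposed_filters walks the same layout in the opposite direction, summing the per-batch transposed buffers back into the (ChanOut, ChanIn/Groups, K, K) filter gradient.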

