Skip to content

Commit

Permalink
Moving NotMixedPrecision to dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jul 24, 2023
1 parent 4f5e138 commit f4a47fd
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 32 deletions.
18 changes: 18 additions & 0 deletions src/dtypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,21 @@ impl Dtype for half::f16 {}
/// Associates a scalar [`Dtype`] with a container type.
pub trait HasDtype {
    /// The scalar element type carried by the implementor.
    type Dtype: Dtype;
}

/// Marker trait for dtypes that are *not* mixed-precision wrappers.
///
/// Generic kernel impls elsewhere in the crate (e.g. `AdamKernel`,
/// `RMSpropKernel`, `SgdKernel`, `SumKernel` for `Cpu`) are bounded by this
/// trait so they do not overlap with the dedicated
/// `crate::dtypes::AMP<crate::dtypes::f16>` impls — Rust's coherence rules
/// forbid two blanket impls that could both apply to the same type.
pub trait NotMixedPrecision {}
// Plain floating-point types.
impl NotMixedPrecision for f32 {}
impl NotMixedPrecision for f64 {}
// Signed integer types.
impl NotMixedPrecision for i8 {}
impl NotMixedPrecision for i16 {}
impl NotMixedPrecision for i32 {}
impl NotMixedPrecision for i64 {}
impl NotMixedPrecision for i128 {}
impl NotMixedPrecision for isize {}
// Unsigned integer types.
impl NotMixedPrecision for u8 {}
impl NotMixedPrecision for u16 {}
impl NotMixedPrecision for u32 {}
impl NotMixedPrecision for u64 {}
impl NotMixedPrecision for u128 {}
impl NotMixedPrecision for usize {}
// Bare `half::f16` is a single precision level, so it qualifies; only the
// `AMP<f16>` wrapper is considered mixed precision.
#[cfg(feature = "f16")]
impl NotMixedPrecision for half::f16 {}
13 changes: 5 additions & 8 deletions src/tensor_ops/adam/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use super::{AdamConfig, AdamKernel, WeightDecay};
use crate::{shapes::Dtype, tensor::Cpu};

trait NonMixedPrecision {}
#[cfg(feature = "f16")]
impl NonMixedPrecision for crate::dtypes::f16 {}
impl NonMixedPrecision for f32 {}
impl NonMixedPrecision for f64 {}
use crate::{
dtypes::{Dtype, NotMixedPrecision},
tensor::Cpu,
};

#[cfg(feature = "f16")]
impl AdamKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
Expand Down Expand Up @@ -54,7 +51,7 @@ impl AdamKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
}
}

impl<E: num_traits::Float + Dtype + NonMixedPrecision> AdamKernel<E> for Cpu {
impl<E: num_traits::Float + Dtype + NotMixedPrecision> AdamKernel<E> for Cpu {
fn adam_kernel(
&self,
t: i32,
Expand Down
13 changes: 5 additions & 8 deletions src/tensor_ops/rmsprop/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use crate::{shapes::Dtype, tensor::cpu::Cpu};
use crate::{
dtypes::{Dtype, NotMixedPrecision},
tensor::cpu::Cpu,
};

use super::{RMSpropConfig, RMSpropKernel, WeightDecay};

trait NonMixedPrecision {}
#[cfg(feature = "f16")]
impl NonMixedPrecision for crate::dtypes::f16 {}
impl NonMixedPrecision for f32 {}
impl NonMixedPrecision for f64 {}

#[cfg(feature = "f16")]
impl RMSpropKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
fn rmsprop_kernel(
Expand Down Expand Up @@ -74,7 +71,7 @@ impl RMSpropKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
}
}

impl<E: num_traits::Float + Dtype + NonMixedPrecision> RMSpropKernel<E> for Cpu {
impl<E: num_traits::Float + Dtype + NotMixedPrecision> RMSpropKernel<E> for Cpu {
fn rmsprop_kernel(
&self,
cfg: &RMSpropConfig,
Expand Down
13 changes: 5 additions & 8 deletions src/tensor_ops/sgd/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use crate::{shapes::Dtype, tensor::cpu::*};
use crate::{
dtypes::{Dtype, NotMixedPrecision},
tensor::cpu::*,
};

use super::{Momentum, SgdConfig, SgdKernel, WeightDecay};

trait NonMixedPrecision {}
#[cfg(feature = "f16")]
impl NonMixedPrecision for crate::dtypes::f16 {}
impl NonMixedPrecision for f32 {}
impl NonMixedPrecision for f64 {}

#[cfg(feature = "f16")]
impl SgdKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
fn sgd_kernel(
Expand Down Expand Up @@ -58,7 +55,7 @@ impl SgdKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
}
}

impl<E: Dtype + NonMixedPrecision> SgdKernel<E> for Cpu {
impl<E: Dtype + NotMixedPrecision> SgdKernel<E> for Cpu {
fn sgd_kernel(
&self,
cfg: &SgdConfig,
Expand Down
11 changes: 3 additions & 8 deletions src/tensor_ops/sum_to/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
shapes::{Axes, Dtype, HasAxes, ReduceShapeTo, Shape},
dtypes::{Dtype, NotMixedPrecision},
shapes::{Axes, HasAxes, ReduceShapeTo, Shape},
tensor::{Cpu, Tensor, Tensorlike, ZerosTensor},
tensor_ops::utilities::reduction_utils::index_for_reductions,
};
Expand Down Expand Up @@ -69,13 +70,7 @@ impl super::SumKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
}
}

trait NonMixedPrecision {}
#[cfg(feature = "f16")]
impl NonMixedPrecision for crate::dtypes::f16 {}
impl NonMixedPrecision for f32 {}
impl NonMixedPrecision for f64 {}

impl<E: Dtype + NonMixedPrecision> super::SumKernel<E> for Cpu {
impl<E: Dtype + NotMixedPrecision> super::SumKernel<E> for Cpu {
fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
&self,
dst: Dst,
Expand Down

0 comments on commit f4a47fd

Please sign in to comment.