impl sum for amp cpu
coreylowman committed Jul 24, 2023
1 parent 051aecf commit 207653c
Showing 1 changed file with 72 additions and 1 deletion.
src/tensor_ops/sum_to/cpu_kernel.rs: 72 additions & 1 deletion
@@ -4,7 +4,78 @@ use crate::{
     tensor_ops::utilities::reduction_utils::index_for_reductions,
 };
 
-impl<E: Dtype> super::SumKernel<E> for Cpu {
+#[cfg(feature = "f16")]
+impl super::SumKernel<crate::dtypes::AMP<crate::dtypes::f16>> for Cpu {
+    fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
+        &self,
+        dst: Dst,
+        inp: &Tensor<Src, crate::dtypes::AMP<crate::dtypes::f16>, Self>,
+    ) -> Result<Tensor<Dst, crate::dtypes::AMP<crate::dtypes::f16>, Self>, Self::Err>
+    where
+        Src: ReduceShapeTo<Dst, Ax>,
+    {
+        let mut out = self.try_zeros_like(&dst)?;
+        if Dst::NUM_DIMS == 0 {
+            debug_assert_eq!(out.data.len(), 1);
+
+            let mut tmp = 0.0f32;
+            for v in inp.buf_iter() {
+                tmp += v.0.to_f32();
+            }
+            let scale = (inp.shape.num_elements() / inp.data.len()) as f32;
+            std::sync::Arc::get_mut(&mut out.data).unwrap()[0] =
+                crate::dtypes::AMP(crate::dtypes::f16::from_f32(tmp * scale));
+        } else {
+            let num_elems_reduced = <Src as HasAxes<Ax>>::size(&inp.shape);
+            let inp_buf = inp.data.as_ref();
+            let mut idx = index_for_reductions::<Src, Ax>(inp.shape, inp.strides);
+            for o in out.buf_iter_mut() {
+                let mut tmp = 0.0f32;
+                for _ in 0..num_elems_reduced {
+                    tmp += inp_buf[idx.next().unwrap()].0.to_f32();
+                }
+                *o = crate::dtypes::AMP(crate::dtypes::f16::from_f32(tmp));
+            }
+        }
+        Ok(out)
+    }
+    fn backward<Src: Shape, Dst: Shape, Ax: Axes>(
+        &self,
+        _dst: Dst,
+        inp: &impl Tensorlike<Src, crate::dtypes::AMP<crate::dtypes::f16>, Self>,
+        grad_inp: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err>
+    where
+        Src: ReduceShapeTo<Dst, Ax>,
+    {
+        if Dst::NUM_DIMS == 0 {
+            debug_assert_eq!(grad_out.len(), 1);
+            let v = grad_out[0].0.to_f32();
+            let scale = (inp.shape().num_elements() / inp.len()) as f32;
+            for i in grad_inp.iter_mut() {
+                i.0 += crate::dtypes::f16::from_f32(v * scale);
+            }
+        } else {
+            let num_elems_reduced = <Src as HasAxes<Ax>>::size(inp.shape());
+            let mut idx = index_for_reductions::<Src, Ax>(*inp.shape(), inp.strides());
+            for &o in grad_out.iter() {
+                for _ in 0..num_elems_reduced {
+                    grad_inp[idx.next().unwrap()] += o;
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+trait NonMixedPrecision {}
+#[cfg(feature = "f16")]
+impl NonMixedPrecision for crate::dtypes::f16 {}
+impl NonMixedPrecision for f32 {}
+impl NonMixedPrecision for f64 {}
+
+impl<E: Dtype + NonMixedPrecision> super::SumKernel<E> for Cpu {
     fn forward<Src: Shape, Dst: Shape, Ax: Axes>(
         &self,
         dst: Dst,
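The noteworthy detail in the new kernel is that partial sums are accumulated in f32 and only converted back to half precision once per output element, which avoids the rounding error of accumulating directly in f16. Below is a standalone sketch of the same pattern using the half crate; it is illustrative only and not part of the commit (the helper name sum_f16 is made up for the example).

use half::f16;

// Accumulate f16 values in an f32 accumulator, rounding back to f16 only once
// at the end, the same pattern the AMP<f16> SumKernel above uses in both the
// scalar branch and the axis-reduction branch.
fn sum_f16(values: &[f16]) -> f16 {
    let mut acc = 0.0f32;
    for v in values {
        acc += v.to_f32();
    }
    f16::from_f32(acc)
}

fn main() {
    // 4096 copies of 0.1: a running sum kept in f16 drifts once the accumulator
    // grows, while the f32 accumulator stays accurate up to one final rounding.
    let xs = vec![f16::from_f32(0.1); 4096];
    println!("{}", sum_f16(&xs)); // close to 409.6
}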

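For context, here is a minimal usage sketch of what this commit enables: a full sum reduction of an AMP<f16> tensor on the Cpu device. This is an assumed call site rather than code from the commit; it presumes the f16 feature is enabled and that dfdx's prelude, OnesTensor, and SumTo APIs are available for AMP<f16>.

use dfdx::dtypes::{f16, AMP};
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // Assumed input: a 2x3 tensor of ones in mixed-precision f16.
    let t: Tensor<Rank2<2, 3>, AMP<f16>, Cpu> = dev.ones();
    // Reduce over all axes; on Cpu this dispatches to the SumKernel<AMP<f16>>
    // implementation added in this commit.
    let s: Tensor<Rank0, AMP<f16>, Cpu> = t.sum();
    println!("{:?}", s.array());
}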