Skip to content

Commit

Permalink
Changing scalar rhs to Into<f64> for binary ops (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Sep 7, 2023
1 parent 67f2568 commit 9c528eb
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 154 deletions.
9 changes: 2 additions & 7 deletions examples/02-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,8 @@ fn main() {
dbg!(f.array());

// and of course you can chain all of these together
let _ = dev
.sample_normal::<Rank2<5, 10>>()
.clamp(-1.0, 1.0)
.exp()
.abs()
.powf(0.5)
/ 2.0;
let _: Tensor<(Const<5>, Const<10>), f32, _> =
dev.sample_normal().clamp(-1.0, 1.0).exp().abs().powf(0.5) / 2.0;

// binary and unary operations can also be performed on dynamically sized tensors
let mut a: Tensor<(Const<3>, usize), f32, _> = dev.sample_uniform_like(&(Const, 5));
Expand Down
7 changes: 2 additions & 5 deletions src/losses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ pub fn huber_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
pub fn smooth_l1_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
pred: Tensor<S, E, D, T>,
targ: Tensor<S, E, D>,
delta: impl Into<f64>,
delta: impl Copy + Into<f64>,
) -> Tensor<Rank0, E, D, T> {
let delta: f64 = delta.into();
huber_loss(pred, targ, delta) / E::from_f64(delta).unwrap()
huber_loss(pred, targ, delta) / delta
}

/// [Cross entropy loss](https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_loss_function_and_logistic_regression).
Expand All @@ -83,7 +82,6 @@ pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<
target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
let inv_last_axis_numel = E::from_f64(inv_last_axis_numel).unwrap();
let probs = logits.log_softmax::<S::LastAxis>();
(probs * target_probs).mean().negate() / inv_last_axis_numel
}
Expand All @@ -103,7 +101,6 @@ pub fn kl_div_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
let inv_last_axis_numel = E::from_f64(inv_last_axis_numel).unwrap();
let probs = logits.log_softmax::<S::LastAxis>();
((probs - target_probs.clone().ln()) * target_probs)
.mean()
Expand Down
9 changes: 2 additions & 7 deletions src/nn/batchnorm2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ where
var.try_axpy(1.0 - momentum, &var_chan, momentum * n / (n - 1.0))?;

// statistics for normalizing - on tape
let std = var_chan
.try_add(E::from_f64(epsilon).unwrap())?
.try_sqrt()?;
let std = var_chan.try_add(epsilon)?.try_sqrt()?;

// record broadcast of scale & bias - on tape
let scale = scale
Expand Down Expand Up @@ -81,10 +79,7 @@ where
let shape = *x.shape();

// statistics for normalizing
let std = var
.clone()
.try_add(E::from_f64(epsilon).unwrap())?
.try_sqrt()?;
let std = var.clone().try_add(epsilon)?.try_sqrt()?;

let scale = scale.clone().try_div(std)?.try_broadcast_like(&shape)?;

Expand Down
2 changes: 1 addition & 1 deletion src/nn/transformer/mha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ where
let q = q.try_permute::<_, Axes4<0, 2, 1, 3>>()?;

// Get weights
let scalar: E = E::from_f64(1.0 / ((K / H) as f64).sqrt()).unwrap();
let scalar = 1.0 / ((K / H) as f64).sqrt();
let weights = q.try_matmul(k)?.try_mul(scalar)?;
let weights = weights.try_softmax::<Axis<3>>()?;

Expand Down
39 changes: 6 additions & 33 deletions src/tensor_ops/add/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,42 +67,15 @@ where
}
}

impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarAddKernelOp<E>, E>, T: Tape<E, D>> TryAdd<E>
for Tensor<S, E, D, T>
{
type Output = Self;
/// See [add]
fn try_add(self, rhs: E) -> Result<Self, Self::Err> {
try_unary_op(ScalarAddKernelOp { scalar: rhs }, self)
}
}

#[cfg(feature = "f16")]
impl<S: Shape, D: UnaryKernel<ScalarAddKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
TryAdd<f32> for Tensor<S, half::f16, D, T>
{
type Output = Self;
/// See [add]
fn try_add(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = half::f16::from_f32(rhs);
try_unary_op(ScalarAddKernelOp { scalar }, self)
}
}

#[cfg(feature = "f16")]
impl<
S: Shape,
D: UnaryKernel<
ScalarAddKernelOp<crate::dtypes::AMP<half::f16>>,
crate::dtypes::AMP<half::f16>,
>,
T: Tape<crate::dtypes::AMP<half::f16>, D>,
> TryAdd<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TryAdd<Rhs> for Tensor<S, E, D, T>
where
D: UnaryKernel<ScalarAddKernelOp<E>, E>,
{
type Output = Self;
/// See [add]
fn try_add(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
fn try_add(self, rhs: Rhs) -> Result<Self, Self::Err> {
let rhs: f64 = rhs.into();
let scalar = E::from_f64(rhs).unwrap();
try_unary_op(ScalarAddKernelOp { scalar }, self)
}
}
Expand Down
39 changes: 6 additions & 33 deletions src/tensor_ops/div/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,42 +65,15 @@ where
}
}

impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarDivKernelOp<E>, E>, T: Tape<E, D>> TryDiv<E>
for Tensor<S, E, D, T>
{
type Output = Self;
/// See [div]
fn try_div(self, rhs: E) -> Result<Self, Self::Err> {
try_unary_op(ScalarDivKernelOp { scalar: rhs }, self)
}
}

#[cfg(feature = "f16")]
impl<S: Shape, D: UnaryKernel<ScalarDivKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
TryDiv<f32> for Tensor<S, half::f16, D, T>
{
type Output = Self;
/// See [div]
fn try_div(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = half::f16::from_f32(rhs);
try_unary_op(ScalarDivKernelOp { scalar }, self)
}
}

#[cfg(feature = "f16")]
impl<
S: Shape,
D: UnaryKernel<
ScalarDivKernelOp<crate::dtypes::AMP<half::f16>>,
crate::dtypes::AMP<half::f16>,
>,
T: Tape<crate::dtypes::AMP<half::f16>, D>,
> TryDiv<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TryDiv<Rhs> for Tensor<S, E, D, T>
where
D: UnaryKernel<ScalarDivKernelOp<E>, E>,
{
type Output = Self;
/// See [div]
fn try_div(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
fn try_div(self, rhs: Rhs) -> Result<Self, Self::Err> {
let rhs: f64 = rhs.into();
let scalar = E::from_f64(rhs).unwrap();
try_unary_op(ScalarDivKernelOp { scalar }, self)
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/tensor_ops/mean_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> MeanTo for Tensor<S, E, D,
Self::Shape: HasAxes<Ax> + ReduceShapeTo<Dst, Ax>,
{
let num_elements_reduced = <S as HasAxes<Ax>>::size(self.shape()) as f64;
let inv_normalize = E::from_f64(1.0 / num_elements_reduced).unwrap();
self.try_sum()?.try_mul(inv_normalize)
self.try_sum()?.try_mul(1.0 / num_elements_reduced)
}
}

Expand Down
37 changes: 6 additions & 31 deletions src/tensor_ops/mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,39 +62,14 @@ where
}
}

impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarMulKernelOp<E>, E>, T: Tape<E, D>> TryMul<E>
for Tensor<S, E, D, T>
{
type Output = Self;
fn try_mul(self, rhs: E) -> Result<Self, Self::Err> {
try_unary_op(ScalarMulKernelOp { scalar: rhs }, self)
}
}

#[cfg(feature = "f16")]
impl<S: Shape, D: UnaryKernel<ScalarMulKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
TryMul<f32> for Tensor<S, half::f16, D, T>
{
type Output = Self;
fn try_mul(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = half::f16::from_f32(rhs);
try_unary_op(ScalarMulKernelOp { scalar }, self)
}
}

#[cfg(feature = "f16")]
impl<
S: Shape,
D: UnaryKernel<
ScalarMulKernelOp<crate::dtypes::AMP<half::f16>>,
crate::dtypes::AMP<half::f16>,
>,
T: Tape<crate::dtypes::AMP<half::f16>, D>,
> TryMul<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TryMul<Rhs> for Tensor<S, E, D, T>
where
D: UnaryKernel<ScalarMulKernelOp<E>, E>,
{
type Output = Self;
fn try_mul(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
fn try_mul(self, rhs: Rhs) -> Result<Self, Self::Err> {
let rhs: f64 = rhs.into();
let scalar: E = E::from_f64(rhs).unwrap();
try_unary_op(ScalarMulKernelOp { scalar }, self)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/tensor_ops/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
.retaped::<T>()
.try_square()?
.try_mean::<_, Ax>()?
.try_add(E::from_f64(epsilon.into()).unwrap())?
.try_add(epsilon)?
.try_sqrt()?;
centered.try_div(std.try_broadcast_like(&shape)?)
}
Expand Down
4 changes: 1 addition & 3 deletions src/tensor_ops/stddev_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> StddevTo<E> for Tensor<S,
where
Self::Shape: HasAxes<Ax> + ReduceShapeTo<Dst, Ax>,
{
self.try_var()?
.try_add(E::from_f64(epsilon.into()).unwrap())?
.try_sqrt()
self.try_var()?.try_add(epsilon)?.try_sqrt()
}
}

Expand Down
37 changes: 6 additions & 31 deletions src/tensor_ops/sub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,39 +63,14 @@ where
}
}

impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarSubKernelOp<E>, E>, T: Tape<E, D>> TrySub<E>
for Tensor<S, E, D, T>
{
type Output = Self;
fn try_sub(self, rhs: E) -> Result<Self, Self::Err> {
try_unary_op(ScalarSubKernelOp { scalar: rhs }, self)
}
}

#[cfg(feature = "f16")]
impl<S: Shape, D: UnaryKernel<ScalarSubKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
TrySub<f32> for Tensor<S, half::f16, D, T>
{
type Output = Self;
fn try_sub(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = half::f16::from_f32(rhs);
try_unary_op(ScalarSubKernelOp { scalar }, self)
}
}

#[cfg(feature = "f16")]
impl<
S: Shape,
D: UnaryKernel<
ScalarSubKernelOp<crate::dtypes::AMP<half::f16>>,
crate::dtypes::AMP<half::f16>,
>,
T: Tape<crate::dtypes::AMP<half::f16>, D>,
> TrySub<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TrySub<Rhs> for Tensor<S, E, D, T>
where
D: UnaryKernel<ScalarSubKernelOp<E>, E>,
{
type Output = Self;
fn try_sub(self, rhs: f32) -> Result<Self, Self::Err> {
let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
fn try_sub(self, rhs: Rhs) -> Result<Self, Self::Err> {
let rhs: f64 = rhs.into();
let scalar = E::from_f64(rhs).unwrap();
try_unary_op(ScalarSubKernelOp { scalar }, self)
}
}
Expand Down

0 comments on commit 9c528eb

Please sign in to comment.