diff --git a/examples/02-ops.rs b/examples/02-ops.rs
index 6e661e729..c940b9b16 100644
--- a/examples/02-ops.rs
+++ b/examples/02-ops.rs
@@ -34,13 +34,8 @@ fn main() {
     dbg!(f.array());
 
     // and of course you can chain all of these together
-    let _ = dev
-        .sample_normal::<Rank2<5, 10>>()
-        .clamp(-1.0, 1.0)
-        .exp()
-        .abs()
-        .powf(0.5)
-        / 2.0;
+    let _: Tensor<(Const<5>, Const<10>), f32, _> =
+        dev.sample_normal().clamp(-1.0, 1.0).exp().abs().powf(0.5) / 2.0;
 
     // binary and unary operations can also be performed on dynamically sized tensors
     let mut a: Tensor<(Const<3>, usize), f32, _> = dev.sample_uniform_like(&(Const, 5));
diff --git a/src/losses.rs b/src/losses.rs
index 8f1292ca2..ea58a41da 100644
--- a/src/losses.rs
+++ b/src/losses.rs
@@ -62,10 +62,9 @@ pub fn huber_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
 pub fn smooth_l1_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
     pred: Tensor<S, E, D, T>,
     targ: Tensor<S, E, D>,
-    delta: impl Into<f64>,
+    delta: impl Copy + Into<f64>,
 ) -> Tensor<Rank0, E, D, T> {
-    let delta: f64 = delta.into();
-    huber_loss(pred, targ, delta) / E::from_f64(delta).unwrap()
+    huber_loss(pred, targ, delta) / delta
 }
 
 /// [Cross entropy loss](https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_loss_function_and_logistic_regression).
@@ -83,7 +82,6 @@ pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<
     target_probs: Tensor<S, E, D>,
 ) -> Tensor<Rank0, E, D, T> {
     let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
-    let inv_last_axis_numel = E::from_f64(inv_last_axis_numel).unwrap();
     let probs = logits.log_softmax::<S::LastAxis>();
     (probs * target_probs).mean().negate() / inv_last_axis_numel
 }
@@ -103,7 +101,6 @@ pub fn kl_div_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
     target_probs: Tensor<S, E, D>,
 ) -> Tensor<Rank0, E, D, T> {
     let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
-    let inv_last_axis_numel = E::from_f64(inv_last_axis_numel).unwrap();
     let probs = logits.log_softmax::<S::LastAxis>();
     ((probs - target_probs.clone().ln()) * target_probs)
         .mean()
diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs
index 400da5590..89d52c790 100644
--- a/src/nn/batchnorm2d.rs
+++ b/src/nn/batchnorm2d.rs
@@ -51,9 +51,7 @@ where
         var.try_axpy(1.0 - momentum, &var_chan, momentum * n / (n - 1.0))?;
 
         // statistics for normalizing - on tape
-        let std = var_chan
-            .try_add(E::from_f64(epsilon).unwrap())?
-            .try_sqrt()?;
+        let std = var_chan.try_add(epsilon)?.try_sqrt()?;
 
         // record broadcast of scale & bias - on tape
         let scale = scale
@@ -81,10 +79,7 @@ where
         let shape = *x.shape();
 
         // statistics for normalizing
-        let std = var
-            .clone()
-            .try_add(E::from_f64(epsilon).unwrap())?
-            .try_sqrt()?;
+        let std = var.clone().try_add(epsilon)?.try_sqrt()?;
 
         let scale = scale.clone().try_div(std)?.try_broadcast_like(&shape)?;
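The `Copy` bound added to `delta` is what lets `smooth_l1_loss` use the value twice (once inside `huber_loss`, once as the divisor) without the intermediate `f64` binding. A minimal call-site sketch, assuming the dfdx prelude and `AutoDevice`; the shape and delta value are illustrative only:

```rust
use dfdx::{losses::smooth_l1_loss, prelude::*};

fn main() {
    let dev = AutoDevice::default();
    let pred: Tensor<Rank1<4>, f32, _> = dev.sample_normal();
    let targ: Tensor<Rank1<4>, f32, _> = dev.sample_normal();
    // `delta` is any `Copy + Into<f64>` value; the scalar ops convert it
    // to the tensor's dtype internally, so a plain literal works for
    // f32 and f64 tensors alike.
    let loss = smooth_l1_loss(pred.leaky_trace(), targ, 1.0);
    dbg!(loss.array());
}
```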
diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs
index fe5dca347..b73901547 100644
--- a/src/nn/transformer/mha.rs
+++ b/src/nn/transformer/mha.rs
@@ -168,7 +168,7 @@ where
         let q = q.try_permute::<_, Axes4<0, 2, 1, 3>>()?;
 
         // Get weights
-        let scalar: E = E::from_f64(1.0 / ((K / H) as f64).sqrt()).unwrap();
+        let scalar = 1.0 / ((K / H) as f64).sqrt();
         let weights = q.try_matmul(k)?.try_mul(scalar)?;
         let weights = weights.try_softmax::<Axis<3>>()?;
diff --git a/src/tensor_ops/add/mod.rs b/src/tensor_ops/add/mod.rs
index 8dec524e2..fe7d9cd8c 100644
--- a/src/tensor_ops/add/mod.rs
+++ b/src/tensor_ops/add/mod.rs
@@ -67,42 +67,15 @@ where
     }
 }
 
-impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarAddKernelOp<E>, E>, T: Tape<E, D>> TryAdd<E>
-    for Tensor<S, E, D, T>
-{
-    type Output = Self;
-    /// See [add]
-    fn try_add(self, rhs: E) -> Result<Self, Self::Err> {
-        try_unary_op(ScalarAddKernelOp { scalar: rhs }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<S: Shape, D: UnaryKernel<ScalarAddKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
-    TryAdd<f32> for Tensor<S, half::f16, D, T>
-{
-    type Output = Self;
-    /// See [add]
-    fn try_add(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = half::f16::from_f32(rhs);
-        try_unary_op(ScalarAddKernelOp { scalar }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<
-        S: Shape,
-        D: UnaryKernel<
-            ScalarAddKernelOp<crate::dtypes::AMP<half::f16>>,
-            crate::dtypes::AMP<half::f16>,
-        >,
-        T: Tape<crate::dtypes::AMP<half::f16>, D>,
-    > TryAdd<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
+impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TryAdd<Rhs> for Tensor<S, E, D, T>
+where
+    D: UnaryKernel<ScalarAddKernelOp<E>, E>,
 {
     type Output = Self;
     /// See [add]
-    fn try_add(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
+    fn try_add(self, rhs: Rhs) -> Result<Self, Self::Err> {
+        let rhs: f64 = rhs.into();
+        let scalar = E::from_f64(rhs).unwrap();
         try_unary_op(ScalarAddKernelOp { scalar }, self)
     }
 }
diff --git a/src/tensor_ops/div/mod.rs b/src/tensor_ops/div/mod.rs
index 4596ae313..6bcf93fd6 100644
--- a/src/tensor_ops/div/mod.rs
+++ b/src/tensor_ops/div/mod.rs
@@ -65,42 +65,15 @@ where
     }
 }
 
-impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarDivKernelOp<E>, E>, T: Tape<E, D>> TryDiv<E>
-    for Tensor<S, E, D, T>
-{
-    type Output = Self;
-    /// See [div]
-    fn try_div(self, rhs: E) -> Result<Self, Self::Err> {
-        try_unary_op(ScalarDivKernelOp { scalar: rhs }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<S: Shape, D: UnaryKernel<ScalarDivKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
-    TryDiv<f32> for Tensor<S, half::f16, D, T>
-{
-    type Output = Self;
-    /// See [div]
-    fn try_div(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = half::f16::from_f32(rhs);
-        try_unary_op(ScalarDivKernelOp { scalar }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<
-        S: Shape,
-        D: UnaryKernel<
-            ScalarDivKernelOp<crate::dtypes::AMP<half::f16>>,
-            crate::dtypes::AMP<half::f16>,
-        >,
-        T: Tape<crate::dtypes::AMP<half::f16>, D>,
-    > TryDiv<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
+impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TryDiv<Rhs> for Tensor<S, E, D, T>
+where
+    D: UnaryKernel<ScalarDivKernelOp<E>, E>,
 {
     type Output = Self;
     /// See [div]
-    fn try_div(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
+    fn try_div(self, rhs: Rhs) -> Result<Self, Self::Err> {
+        let rhs: f64 = rhs.into();
+        let scalar = E::from_f64(rhs).unwrap();
         try_unary_op(ScalarDivKernelOp { scalar }, self)
     }
 }
diff --git a/src/tensor_ops/mean_to.rs b/src/tensor_ops/mean_to.rs
index 05c3676f8..c87632cae 100644
--- a/src/tensor_ops/mean_to.rs
+++ b/src/tensor_ops/mean_to.rs
@@ -40,8 +40,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> MeanTo for Tensor<S, E, D
         Self::Shape: HasAxes<Ax> + ReduceShapeTo<Dst, Ax>,
     {
         let num_elements_reduced = <S as HasAxes<Ax>>::size(self.shape()) as f64;
-        let inv_normalize = E::from_f64(1.0 / num_elements_reduced).unwrap();
-        self.try_sum()?.try_mul(inv_normalize)
+        self.try_sum()?.try_mul(1.0 / num_elements_reduced)
     }
 }
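For each scalar op, the three impls (one per dtype family, two of them f16-only) collapse into a single impl generic over `Rhs: Into<f64>`, with the conversion into the element type `E` done once via `E::from_f64`. A rough sketch of what this looks like from the call site, again assuming `AutoDevice` and the prelude (shapes illustrative):

```rust
use dfdx::prelude::*;

fn main() {
    let dev = AutoDevice::default();

    // One impl now serves every dtype: the same `+ 1.0` / `/ 2.0`
    // literals work whether the tensor holds f32, f64, or (with the
    // `f16` feature) half-precision values.
    let a: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
    let _ = a.clone() + 1.0;
    let _ = a / 2.0;

    let b: Tensor<Rank2<2, 3>, f64, _> = dev.sample_normal();
    let _ = b * 0.5;
}
```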
diff --git a/src/tensor_ops/mul/mod.rs b/src/tensor_ops/mul/mod.rs
index 4aae683bb..4b66bfbcb 100644
--- a/src/tensor_ops/mul/mod.rs
+++ b/src/tensor_ops/mul/mod.rs
@@ -62,39 +62,14 @@ where
     }
 }
 
-impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarMulKernelOp<E>, E>, T: Tape<E, D>> TryMul<E>
-    for Tensor<S, E, D, T>
-{
-    type Output = Self;
-    fn try_mul(self, rhs: E) -> Result<Self, Self::Err> {
-        try_unary_op(ScalarMulKernelOp { scalar: rhs }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<S: Shape, D: UnaryKernel<ScalarMulKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
-    TryMul<f32> for Tensor<S, half::f16, D, T>
-{
-    type Output = Self;
-    fn try_mul(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = half::f16::from_f32(rhs);
-        try_unary_op(ScalarMulKernelOp { scalar }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<
-        S: Shape,
-        D: UnaryKernel<
-            ScalarMulKernelOp<crate::dtypes::AMP<half::f16>>,
-            crate::dtypes::AMP<half::f16>,
-        >,
-        T: Tape<crate::dtypes::AMP<half::f16>, D>,
-    > TryMul<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
+impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TryMul<Rhs> for Tensor<S, E, D, T>
+where
+    D: UnaryKernel<ScalarMulKernelOp<E>, E>,
 {
     type Output = Self;
-    fn try_mul(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
+    fn try_mul(self, rhs: Rhs) -> Result<Self, Self::Err> {
+        let rhs: f64 = rhs.into();
+        let scalar: E = E::from_f64(rhs).unwrap();
         try_unary_op(ScalarMulKernelOp { scalar }, self)
     }
 }
diff --git a/src/tensor_ops/normalize.rs b/src/tensor_ops/normalize.rs
index 3e929dcf5..019e705a7 100644
--- a/src/tensor_ops/normalize.rs
+++ b/src/tensor_ops/normalize.rs
@@ -46,7 +46,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
             .retaped::<T>()
             .try_square()?
             .try_mean::<_, Ax>()?
-            .try_add(E::from_f64(epsilon.into()).unwrap())?
+            .try_add(epsilon)?
             .try_sqrt()?;
         centered.try_div(std.try_broadcast_like(&shape)?)
     }
diff --git a/src/tensor_ops/stddev_to.rs b/src/tensor_ops/stddev_to.rs
index 5c2811c64..0b38796b3 100644
--- a/src/tensor_ops/stddev_to.rs
+++ b/src/tensor_ops/stddev_to.rs
@@ -38,9 +38,7 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> StddevTo<E> for Tensor<S,
     where
         Self::Shape: HasAxes<Ax> + ReduceShapeTo<Dst, Ax>,
     {
-        self.try_var()?
-            .try_add(E::from_f64(epsilon.into()).unwrap())?
-            .try_sqrt()
+        self.try_var()?.try_add(epsilon)?.try_sqrt()
     }
 }
diff --git a/src/tensor_ops/sub/mod.rs b/src/tensor_ops/sub/mod.rs
index da82c575a..af58b74ef 100644
--- a/src/tensor_ops/sub/mod.rs
+++ b/src/tensor_ops/sub/mod.rs
@@ -63,39 +63,14 @@ where
     }
 }
 
-impl<S: Shape, E: Dtype, D: UnaryKernel<ScalarSubKernelOp<E>, E>, T: Tape<E, D>> TrySub<E>
-    for Tensor<S, E, D, T>
-{
-    type Output = Self;
-    fn try_sub(self, rhs: E) -> Result<Self, Self::Err> {
-        try_unary_op(ScalarSubKernelOp { scalar: rhs }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<S: Shape, D: UnaryKernel<ScalarSubKernelOp<half::f16>, half::f16>, T: Tape<half::f16, D>>
-    TrySub<f32> for Tensor<S, half::f16, D, T>
-{
-    type Output = Self;
-    fn try_sub(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = half::f16::from_f32(rhs);
-        try_unary_op(ScalarSubKernelOp { scalar }, self)
-    }
-}
-
-#[cfg(feature = "f16")]
-impl<
-        S: Shape,
-        D: UnaryKernel<
-            ScalarSubKernelOp<crate::dtypes::AMP<half::f16>>,
-            crate::dtypes::AMP<half::f16>,
-        >,
-        T: Tape<crate::dtypes::AMP<half::f16>, D>,
-    > TrySub<f32> for Tensor<S, crate::dtypes::AMP<half::f16>, D, T>
+impl<S: Shape, E: Dtype, Rhs: Into<f64>, D, T: Tape<E, D>> TrySub<Rhs> for Tensor<S, E, D, T>
+where
+    D: UnaryKernel<ScalarSubKernelOp<E>, E>,
 {
     type Output = Self;
-    fn try_sub(self, rhs: f32) -> Result<Self, Self::Err> {
-        let scalar = crate::dtypes::AMP(half::f16::from_f32(rhs));
+    fn try_sub(self, rhs: Rhs) -> Result<Self, Self::Err> {
+        let rhs: f64 = rhs.into();
+        let scalar = E::from_f64(rhs).unwrap();
         try_unary_op(ScalarSubKernelOp { scalar }, self)
     }
 }
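The same convention carries through the epsilon-style arguments: `normalize` and `stddev` already took `impl Into<f64>`, and their bodies can now forward that value straight into `try_add` instead of converting by hand. A short usage sketch under the same assumptions as the earlier examples (the turbofish forms follow the dfdx docs; shapes and epsilon are illustrative):

```rust
use dfdx::prelude::*;

fn main() {
    let dev = AutoDevice::default();
    let t: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();

    // epsilon stays a plain f64 literal; try_add converts it to the
    // tensor's dtype internally, so callers no longer need
    // E::from_f64(eps).unwrap().
    let n = t.clone().normalize::<Axis<1>>(1e-5);
    let s = t.stddev::<Rank1<2>, _>(1e-5);
    dbg!((n.array(), s.array()));
}
```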