Skip to content

Commit

Permalink
fix tests for f64 and math for no-std
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Dec 2, 2023
1 parent 28e02d3 commit e24c8d1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
8 changes: 7 additions & 1 deletion dfdx-core/src/tensor_ops/prodigy/cpu_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,13 @@ impl<E: num_traits::Float + Dtype + NotMixedPrecision> ProdigyKernel<E> for Cpu
) -> Result<(), Error> {
let mut d_denom_: E = E::zero();
let [beta1, beta2] = cfg.betas.map(E::from_f64).map(Option::unwrap);
let beta3 = E::from_f64(cfg.beta3.unwrap_or_else(|| cfg.betas[1].sqrt())).unwrap();
let beta3 = E::from_f64(cfg.beta3.unwrap_or_else(|| {
#[cfg(feature = "no-std")]
use num_traits::Float;

cfg.betas[1].sqrt()
}))
.unwrap();

let bias_correction = if cfg.use_bias_correction {
// note: in here the first k = 1, whereas on the reference python code it's 0
Expand Down
15 changes: 15 additions & 0 deletions dfdx/src/nn/optim/prodigy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,27 @@ mod tests {
.zip(expected_updates)
{
let prediction = m.forward_mut(x.trace(grads));

#[cfg(feature = "test-f64")]
assert_close_to_literal!(prediction, ey, 7e-5);
#[cfg(not(feature = "test-f64"))]
assert_close_to_literal!(prediction, ey);

let loss = crate::losses::mse_loss(prediction, dev.tensor(y));
grads = loss.backward();

#[cfg(feature = "test-f64")]
assert_close_to_literal!(grads.get(&m.weight), eg, 3e-5);
#[cfg(not(feature = "test-f64"))]
assert_close_to_literal!(grads.get(&m.weight), eg);

opt.update(&mut m, &grads).expect("");

#[cfg(feature = "test-f64")]
assert_close_to_literal!(m.weight, eu, 5e-4);
#[cfg(not(feature = "test-f64"))]
assert_close_to_literal!(m.weight, eu);

m.zero_grads(&mut grads);
}
}
Expand Down

0 comments on commit e24c8d1

Please sign in to comment.