diff --git a/src/nn/conv.rs b/src/nn/conv.rs index c53fe2d43..879836749 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -108,7 +108,8 @@ where |s| &s.weight, |s| &mut s.weight, TensorOptions::reset_with(|t| { - let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); + let scale = E::from_f64(G as f64 / (I * K * K) as f64).unwrap(); + let b = scale.sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), ), diff --git a/src/nn/convtrans.rs b/src/nn/convtrans.rs index da00c05c3..179d75c94 100644 --- a/src/nn/convtrans.rs +++ b/src/nn/convtrans.rs @@ -104,7 +104,8 @@ where |s| &s.weight, |s| &mut s.weight, TensorOptions::reset_with(|t| { - let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); + let scale = E::from_f64(G as f64 / (I * K * K) as f64).unwrap(); + let b = scale.sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), ),