From 5a722fba47f494c4a50744b742c7f66ca4311815 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Mon, 10 Jul 2023 08:16:37 -0400 Subject: [PATCH] Using Groups in conv weight init --- src/nn/conv.rs | 3 ++- src/nn/convtrans.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nn/conv.rs b/src/nn/conv.rs index c53fe2d43..879836749 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -108,7 +108,8 @@ where |s| &s.weight, |s| &mut s.weight, TensorOptions::reset_with(|t| { - let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); + let scale = E::from_f64(G as f64 / (I * K * K) as f64).unwrap(); + let b = scale.sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), ), diff --git a/src/nn/convtrans.rs b/src/nn/convtrans.rs index da00c05c3..179d75c94 100644 --- a/src/nn/convtrans.rs +++ b/src/nn/convtrans.rs @@ -104,7 +104,8 @@ where |s| &s.weight, |s| &mut s.weight, TensorOptions::reset_with(|t| { - let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); + let scale = E::from_f64(G as f64 / (I * K * K) as f64).unwrap(); + let b = scale.sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), ),