diff --git a/src/tensor_ops/adam/mod.rs b/src/tensor_ops/adam/mod.rs index 5f847d2a1..0188f8aeb 100644 --- a/src/tensor_ops/adam/mod.rs +++ b/src/tensor_ops/adam/mod.rs @@ -72,7 +72,7 @@ impl AdamConfig { param.device.adam_kernel( t, self, - std::sync::Arc::get_mut(&mut param.data).unwrap(), + std::sync::Arc::make_mut(&mut param.data), moment1, moment2, grad, diff --git a/src/tensor_ops/rmsprop/mod.rs b/src/tensor_ops/rmsprop/mod.rs index ecf2598c1..b4095a2a7 100644 --- a/src/tensor_ops/rmsprop/mod.rs +++ b/src/tensor_ops/rmsprop/mod.rs @@ -69,7 +69,7 @@ impl RMSpropConfig { ) -> Result<(), D::Err> { param.device.rmsprop_kernel( self, - std::sync::Arc::get_mut(&mut param.data).unwrap(), + std::sync::Arc::make_mut(&mut param.data), momentum, square_avg, grad_avg, diff --git a/src/tensor_ops/sgd/mod.rs b/src/tensor_ops/sgd/mod.rs index bb917cc9d..be27f1393 100644 --- a/src/tensor_ops/sgd/mod.rs +++ b/src/tensor_ops/sgd/mod.rs @@ -102,7 +102,7 @@ impl SgdConfig { ) -> Result<(), D::Err> { param.device.sgd_kernel( self, - std::sync::Arc::get_mut(&mut param.data).unwrap(), + std::sync::Arc::make_mut(&mut param.data), velocity, grad, )