impl Tape for Arc<Mutex<OwnedTape>> (#835)
* impl Tape for Arc<Mutex<OwnedTape>>

* Wrapping Arc<Mutex<OwnedTape>> behind std feature
coreylowman committed Jul 31, 2023
1 parent ed09589 commit 7763510
Showing 2 changed files with 71 additions and 2 deletions.
55 changes: 53 additions & 2 deletions src/tensor/gradients.rs
@@ -199,19 +199,30 @@ impl<E: std::fmt::Debug, D: Storage<E>> std::fmt::Debug for OwnedTape<E, D> {
    }
}

impl<E, D: Storage<E>> From<Gradients<E, D>> for OwnedTape<E, D> {
    fn from(gradients: Gradients<E, D>) -> Self {
        Self {
            operations: Default::default(),
            gradients,
        }
    }
}

impl<E, D: Storage<E>> OwnedTape<E, D> {
    /// Compute the [Gradients]! This just runs all the operations on a new [Gradients] struct.
    ///
    /// Note that this method takes ownership of self, so it can't be called twice!
-    pub(crate) fn execute(mut self) -> Result<Gradients<E, D>, D::Err> {
+    pub(crate) fn execute(&mut self) -> Result<Gradients<E, D>, D::Err> {
        // We must ensure that the operations are sorted in execution time order.
        // Otherwise an backward operation may not be executed in the right order
        // if multiple tapes were merged together.
        self.operations.sort_by_key(|(k, _)| *k);
        // In case the same operation is present multiple times, we dedup it.
        self.operations.dedup_by_key(|(k, _)| *k);
        for (_, operation) in self.operations.drain(..).rev() {
            (operation)(&mut self.gradients)?;
        }
-        Ok(self.gradients)
+        Ok(std::mem::replace(&mut self.gradients, Gradients::leaky()))
    }
}
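Changing execute to take &mut self is what lets the backward impl for Arc<Mutex<OwnedTape>> further down run the tape through a MutexGuard instead of by value; std::mem::replace moves the accumulated gradients out and leaves a fresh leaky Gradients in the tape. A minimal sketch of that move-out pattern with plain std types (illustrative only, not part of this diff):

use std::mem;

struct Recorder {
    results: Vec<f32>,
}

impl Recorder {
    // Taking &mut self means a caller that only holds a lock guard (e.g. a
    // MutexGuard<Recorder>) can still finish the pass; the accumulated
    // results are moved out and an empty Vec is left in their place.
    fn finish(&mut self) -> Vec<f32> {
        mem::replace(&mut self.results, Vec::new())
    }
}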

@@ -282,3 +293,43 @@ impl<E, D: Storage<E>> Merge<OwnedTape<E, D>> for OwnedTape<E, D> {
        self
    }
}

#[cfg(feature = "std")]
impl<E, D: Storage<E>> Merge<NoneTape> for std::sync::Arc<std::sync::Mutex<OwnedTape<E, D>>> {
    fn merge(self, _: NoneTape) -> Self {
        self
    }
}

#[cfg(feature = "std")]
impl<E, D: Storage<E>> Merge<Self> for std::sync::Arc<std::sync::Mutex<OwnedTape<E, D>>> {
    fn merge(self, other: Self) -> Self {
        if !std::sync::Arc::ptr_eq(&self, &other) {
            let mut lhs = self.lock().unwrap();
            let mut rhs = other.lock().unwrap();
            lhs.gradients
                .gradient_by_id
                .append(&mut rhs.gradients.gradient_by_id);
            if let Some(leafs) = &mut rhs.gradients.leaf_ids {
                lhs.gradients
                    .leaf_ids
                    .get_or_insert_with(Default::default)
                    .append(leafs);
            }
            lhs.operations.append(&mut rhs.operations);
        }
        self
    }
}
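The Arc::ptr_eq check above matters: merge can be handed two clones of the same tape handle, and locking the same Mutex twice from one thread would deadlock, so both sides are locked only when they are distinct allocations. A small stand-alone sketch of the same guard using plain std types (illustrative, not taken from this commit):

use std::sync::{Arc, Mutex};

fn merge_logs(lhs: Arc<Mutex<Vec<i32>>>, rhs: Arc<Mutex<Vec<i32>>>) -> Arc<Mutex<Vec<i32>>> {
    // Only lock both handles when they point at different allocations;
    // if they alias, there is nothing to merge and locking twice would hang.
    if !Arc::ptr_eq(&lhs, &rhs) {
        let mut a = lhs.lock().unwrap();
        let mut b = rhs.lock().unwrap();
        a.append(&mut b);
    }
    lhs
}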

#[cfg(feature = "std")]
impl<E, D: Storage<E>> Tape<E, D> for std::sync::Arc<std::sync::Mutex<OwnedTape<E, D>>> {
    const OWNS_TAPE: bool = true;
    fn add_backward_op<F>(&mut self, operation: F)
    where
        F: 'static + FnOnce(&mut Gradients<E, D>) -> Result<(), D::Err>,
    {
        let mut tape = self.lock().unwrap();
        tape.add_backward_op(operation);
    }
}
18 changes: 18 additions & 0 deletions src/tensor_ops/utilities/backward.rs
@@ -28,3 +28,21 @@ impl<E: 'static + Clone, D: OneFillStorage<E>> Backward<E, D>
        Ok(grads)
    }
}

#[cfg(feature = "std")]
impl<E: 'static + Clone, D: OneFillStorage<E>> Backward<E, D>
    for Tensor<Rank0, E, D, std::sync::Arc<std::sync::Mutex<OwnedTape<E, D>>>>
{
    fn try_backward(self) -> Result<Gradients<E, D>, Self::Err> {
        let (t, tape) = self.split_tape();
        let t_ghost = t.ghost();
        let mut tape = tape.lock().unwrap();
        tape.add_backward_op(move |grads| {
            grads.try_alloc_for(&t_ghost)?;
            t.device.try_fill_with_ones(grads.get_mut(&t_ghost))
        });
        let mut grads = tape.execute()?;
        grads.drop_non_leafs();
        Ok(grads)
    }
}
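Taken together, the two files let a tensor carry Arc<Mutex<OwnedTape>> as its tape and still call backward(). A rough usage sketch follows; it is not taken from the repository and assumes dfdx's ZerosTensor::zeros, PutTape::put_tape, exp, and Backward::backward behave as elsewhere in the crate, and that OwnedTape implements Default:

use std::sync::{Arc, Mutex};
use dfdx::prelude::*;

fn main() {
    let dev = Cpu::default();
    // One tape handle; clones of this Arc all record onto the same OwnedTape.
    let tape: Arc<Mutex<OwnedTape<f32, Cpu>>> = Default::default();
    let x: Tensor<Rank0, f32, _> = dev.zeros();
    // The scalar result still carries the shared tape, so the new Backward
    // impl for Arc<Mutex<OwnedTape>> is the one that runs here.
    let y = x.put_tape(tape).exp();
    let _grads = y.backward();
}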
