Free cache memory on TensorCache drop
- Moved the implementation of `Cache::try_empty_cache()` to `TensorCache::try_clear()`.
  - The new method is invoked both by `Cache::try_empty_cache()` and by `drop(TensorCache)`, so cached device memory is now freed when the cache itself is dropped.
- Moved the device cache pointer deallocation to `BytesPtr`, `CudaBytesPtr` (a newtype over `CUdeviceptr`), and `Buffer`.
  - This is abstracted by the `CachePtr` trait, whose `dealloc` method is called by `TensorCache::try_clear()`.
  - Deallocation may require extra device information, such as the `Arc<CudaDevice>` in the CUDA case; that information is now held by `TensorCache` (see the sketch below).
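
For reference, here is a minimal standalone sketch of the delegation pattern described above (not the actual dfdx-core code), with simplified stand-in types: a hypothetical `HostPtr` in place of `BytesPtr`, a unit `()` device handle, and a bare `Result<(), ()>` error type. Cached pointers implement `CachePtr::dealloc`, the cache owns whatever device handle deallocation needs, and both explicit clearing and `Drop` funnel into `try_clear()`.

```rust
use std::collections::BTreeMap;
use std::sync::RwLock;

/// Simplified stand-in for dfdx's `AllocationKey`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct AllocationKey {
    num_bytes: usize,
}

/// Cached pointers know how to free themselves, given the allocation key and
/// whatever device state the cache carries (e.g. an `Arc<CudaDevice>` on CUDA).
trait CachePtr<Dev>: Sized {
    // Default: do nothing (plain value types are simply dropped).
    fn dealloc(self, _key: &AllocationKey, _dev: &Dev) {}
}

struct TensorCache<Ptr: CachePtr<Dev>, Dev = ()> {
    allocations: RwLock<BTreeMap<AllocationKey, Vec<Ptr>>>,
    device_dev: Dev,
}

impl<Ptr: CachePtr<Dev>, Dev> TensorCache<Ptr, Dev> {
    /// Frees every cached allocation and empties the cache.
    fn try_clear(&self) -> Result<(), ()> {
        let mut cache = self.allocations.write().unwrap();
        for (&key, allocations) in cache.iter_mut() {
            for alloc in allocations.drain(..) {
                alloc.dealloc(&key, &self.device_dev);
            }
        }
        cache.clear();
        Ok(())
    }
}

impl<Ptr: CachePtr<Dev>, Dev> Drop for TensorCache<Ptr, Dev> {
    fn drop(&mut self) {
        // Freeing on drop is the behavior added by this commit.
        self.try_clear().unwrap();
    }
}

/// Hypothetical host pointer standing in for `BytesPtr`; a real impl would
/// rebuild and drop the original allocation from the raw pointer and key.
struct HostPtr(*mut u8);

impl CachePtr<()> for HostPtr {
    fn dealloc(self, key: &AllocationKey, _dev: &()) {
        let _ = (self.0, key.num_bytes);
    }
}

fn main() {
    let cache: TensorCache<HostPtr> = TensorCache {
        allocations: RwLock::new(BTreeMap::new()),
        device_dev: (),
    };
    // Dropping the cache now deallocates anything it still holds.
    drop(cache);
}
```

In the real commit the same pattern is instantiated per backend: `BytesPtr` with `CpuDevice`, `CudaBytesPtr` with `Arc<CudaDevice>`, and `Buffer` with `WebGpuDevice`, as the diffs below show.
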
swfsql committed Mar 1, 2024
1 parent 4251a77 commit 8f43ce5
Showing 4 changed files with 146 additions and 76 deletions.
74 changes: 71 additions & 3 deletions dfdx-core/src/tensor/cache.rs
@@ -33,21 +33,35 @@ pub(crate) struct AllocationKey {
/// valid allocation. When the last value is removed from the list, the key
/// is removed.
#[derive(Debug)]
pub(crate) struct TensorCache<Ptr> {
pub(crate) struct TensorCache<Ptr: CachePtr<DeviceDev>, DeviceDev = ()> {
pub(crate) allocations: RwLock<BTreeMap<AllocationKey, Vec<Ptr>>>,
pub(crate) enabled: RwLock<bool>,
device_dev: DeviceDev,
}

impl<Ptr> Default for TensorCache<Ptr> {
impl<Ptr: CachePtr<DeviceDev>, DeviceDev: Default> Default for TensorCache<Ptr, DeviceDev> {
fn default() -> Self {
Self {
allocations: Default::default(),
enabled: RwLock::new(false),
device_dev: DeviceDev::default(),
}
}
}

impl<Ptr> TensorCache<Ptr> {
#[allow(dead_code)]
impl<Ptr: CachePtr<DeviceDev>, DeviceDev> TensorCache<Ptr, DeviceDev> {
/// Creates an empty [TensorCache] with the given `device_dev`.
pub(crate) fn new(device_dev: DeviceDev) -> Self {
Self {
allocations: Default::default(),
enabled: RwLock::new(false),
device_dev,
}
}
}

impl<Ptr: CachePtr<DeviceDev>, DeviceDev> TensorCache<Ptr, DeviceDev> {
/// Returns the number of allocations in the cache.
#[allow(unused)]
pub(crate) fn len(&self) -> usize {
@@ -183,6 +197,60 @@ impl<Ptr> TensorCache<Ptr> {
}
}

impl<Ptr: CachePtr<DeviceDev>, DeviceDev> TensorCache<Ptr, DeviceDev> {
/// Deallocates all cached memory on the device and empties the cache.
pub(crate) fn try_clear(&self) -> Result<(), crate::prelude::Error> {
let mut cache = {
#[cfg(not(feature = "no-std"))]
{
self.allocations.write().unwrap()
}
#[cfg(feature = "no-std")]
{
self.allocations.write()
}
};

for (&key, allocations) in cache.iter_mut() {
for alloc in allocations.drain(..) {
alloc.dealloc(&key, &self.device_dev);
}
}
cache.clear();
Ok(())
}
}

impl<Ptr: CachePtr<DeviceDev>, DeviceDev> Drop for TensorCache<Ptr, DeviceDev> {
fn drop(&mut self) {
self.try_clear().unwrap();
}
}

/// Functionality internalized by the pointer.
pub(crate) trait CachePtr<Dev>: Sized {
// by default no deallocation is made for any cache ptr,
// i.e. they leak
/// Deallocates the memory referred to by this pointer.
fn dealloc(self, _key: &AllocationKey, _dev: &Dev) {}
}

impl<Dev> CachePtr<Dev> for bool {}
impl<Dev> CachePtr<Dev> for u8 {}
impl<Dev> CachePtr<Dev> for u16 {}
impl<Dev> CachePtr<Dev> for u32 {}
impl<Dev> CachePtr<Dev> for u64 {}
impl<Dev> CachePtr<Dev> for u128 {}
impl<Dev> CachePtr<Dev> for usize {}
impl<Dev> CachePtr<Dev> for i8 {}
impl<Dev> CachePtr<Dev> for i16 {}
impl<Dev> CachePtr<Dev> for i32 {}
impl<Dev> CachePtr<Dev> for i64 {}
impl<Dev> CachePtr<Dev> for i128 {}
impl<Dev> CachePtr<Dev> for isize {}
impl<Dev> CachePtr<Dev> for f32 {}
impl<Dev> CachePtr<Dev> for f64 {}

#[cfg(test)]
mod test {
use super::*;
84 changes: 42 additions & 42 deletions dfdx-core/src/tensor/cpu/device.rs
@@ -25,7 +25,7 @@ pub struct Cpu {
/// A thread safe random number generator.
pub(crate) rng: Arc<Mutex<StdRng>>,
/// A thread safe cache of memory allocations that can be reused.
pub(crate) cache: Arc<TensorCache<BytesPtr>>,
pub(crate) cache: Arc<TensorCache<BytesPtr, CpuDevice>>,
}

impl Default for Cpu {
@@ -47,14 +47,53 @@ impl Cpu {
}
}

/// Unit struct to represent information needed for managing allocations on the Cpu.
#[derive(Clone, Debug, Default)]
pub(crate) struct CpuDevice;

impl crate::tensor::cache::CachePtr<CpuDevice> for BytesPtr {
fn dealloc(self, key: &crate::tensor::cache::AllocationKey, _dev: &CpuDevice) {
assert!(key.num_bytes % key.size == 0);
assert!(key.num_bytes < isize::MAX as usize);
let len = key.num_bytes / key.size;
let cap = len;
// SAFETY:
// - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function."
// - ✅ cpu uses global allocator
// - "T needs to have the same alignment as what ptr was allocated with."
// - ✅ we are matching on the alignment below
// - "The size of T times the capacity needs to be the same size as the pointer was allocated with."
// - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above
// - "length needs to be less than or equal to capacity."
// - ✅ they are equal
// - "The first length values must be properly initialized values of type T."
// - ✅ any bit pattern is valid for unsigned ints used below
// - "capacity needs to be the capacity that the pointer was allocated with."
// - ✅ handled by assertion above (key.num_bytes % key.size == 0)
// - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset."
// - ✅ handled by assertion above
debug_assert_eq!(std::alloc::Layout::new::<u8>().align(), 1);
debug_assert_eq!(std::alloc::Layout::new::<u16>().align(), 2);
debug_assert_eq!(std::alloc::Layout::new::<u32>().align(), 4);
debug_assert_eq!(std::alloc::Layout::new::<u64>().align(), 8);
match key.alignment {
1 => unsafe { drop(Vec::from_raw_parts(self.0, len, cap)) },
2 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u16, len, cap)) },
4 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u32, len, cap)) },
8 => unsafe { drop(Vec::from_raw_parts(self.0 as *mut u64, len, cap)) },
_ => unreachable!(),
};
}
}

/// A [Vec] that can be cloned without allocating new memory.
/// When [Drop]ed it will insert it's data into the cache.
#[derive(Debug)]
pub struct CachableVec<E> {
/// The data stored in this vector.
pub(crate) data: Vec<E>,
/// A cache of memory allocations that can be reused.
pub(crate) cache: Arc<TensorCache<BytesPtr>>,
pub(crate) cache: Arc<TensorCache<BytesPtr, CpuDevice>>,
}

impl<E: Clone> Clone for CachableVec<E> {
@@ -166,45 +205,6 @@ impl Cache for Cpu {
}

fn try_empty_cache(&self) -> Result<(), Error> {
#[cfg(not(feature = "no-std"))]
let mut cache = self.cache.allocations.write().unwrap();
#[cfg(feature = "no-std")]
let mut cache = self.cache.allocations.write();
for (&key, allocations) in cache.iter_mut() {
assert!(key.num_bytes % key.size == 0);
assert!(key.num_bytes < isize::MAX as usize);
let len = key.num_bytes / key.size;
let cap = len;
for alloc in allocations.drain(..) {
// SAFETY:
// - "ptr must have been allocated using the global allocator, such as via the alloc::alloc function."
// - ✅ cpu uses global allocator
// - "T needs to have the same alignment as what ptr was allocated with."
// - ✅ we are matching on the alignment below
// - "The size of T times the capacity needs to be the same size as the pointer was allocated with."
// - ✅ covered by `key.num_bytes / key.size` and the `key.num_bytes % key.size == 0` assertion above
// - "length needs to be less than or equal to capacity."
// - ✅ they are equal
// - "The first length values must be properly initialized values of type T."
// - ✅ any bit pattern is valid for unsigned ints used below
// - "capacity needs to be the capacity that the pointer was allocated with."
// - ✅ handled by assertion above (key.num_bytes % key.size == 0)
// - "The allocated size in bytes must be no larger than isize::MAX. See the safety documentation of pointer::offset."
// - ✅ handled by assertion above
debug_assert_eq!(std::alloc::Layout::new::<u8>().align(), 1);
debug_assert_eq!(std::alloc::Layout::new::<u16>().align(), 2);
debug_assert_eq!(std::alloc::Layout::new::<u32>().align(), 4);
debug_assert_eq!(std::alloc::Layout::new::<u64>().align(), 8);
match key.alignment {
1 => unsafe { drop(Vec::from_raw_parts(alloc.0, len, cap)) },
2 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u16, len, cap)) },
4 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u32, len, cap)) },
8 => unsafe { drop(Vec::from_raw_parts(alloc.0 as *mut u64, len, cap)) },
_ => unreachable!(),
};
}
}
cache.clear();
Ok(())
self.cache.try_clear()
}
}
38 changes: 20 additions & 18 deletions dfdx-core/src/tensor/cuda/device.rs
@@ -29,7 +29,7 @@ pub struct Cuda {
/// A second stream for kernels to optionally execute on.
pub(crate) par_stream: Arc<CudaStream>,
pub(crate) workspace: Arc<Mutex<CudaSlice<u8>>>,
pub(crate) cache: Arc<TensorCache<CUdeviceptr>>,
pub(crate) cache: Arc<TensorCache<CudaBytesPtr, Arc<CudaDevice>>>,
}

impl From<CublasError> for Error {
@@ -77,6 +77,7 @@ impl Cuda {
let cudnn = cudarc::cudnn::Cudnn::new(dev.clone())?;
let par_stream = Arc::new(dev.fork_default_stream()?);
let workspace = Arc::new(Mutex::new(dev.alloc_zeros::<u8>(1)?));
let cache = Arc::new(TensorCache::new(Arc::clone(&dev)));
Ok(Self {
cpu,
dev,
@@ -85,7 +86,7 @@
cudnn,
par_stream,
workspace,
cache: Default::default(),
cache,
})
}
}
@@ -100,7 +101,7 @@ impl Cuda {
) -> Result<CudaSlice<E>, Error> {
let data = self.cache.try_pop::<E>(len).map_or_else(
|| self.dev.alloc::<E>(len),
|ptr| Ok(self.dev.upgrade_device_ptr(ptr, len)),
|ptr| Ok(self.dev.upgrade_device_ptr(ptr.0, len)),
)?;
Ok(data)
}
@@ -122,14 +123,26 @@
}
}

/// A pointer to bytes on the Cuda device. Used in conjunction with [TensorCache].
#[repr(transparent)]
#[derive(Clone, Debug)]
pub struct CudaBytesPtr(pub(crate) CUdeviceptr);

impl crate::tensor::cache::CachePtr<Arc<CudaDevice>> for CudaBytesPtr {
fn dealloc(self, key: &crate::tensor::cache::AllocationKey, dev: &Arc<CudaDevice>) {
let data = unsafe { dev.upgrade_device_ptr::<u8>(self.0, key.num_bytes) };
drop(data);
}
}

/// A [CudaSlice] that can be cloned without allocating new memory.
/// When [Drop]ed it will insert it's data into the cache.
#[derive(Debug)]
pub struct CachableCudaSlice<E> {
/// The actual data.
pub(crate) data: CudaSlice<E>,
/// A cache of device pointers that can be reused.
pub(crate) cache: Arc<TensorCache<CUdeviceptr>>,
pub(crate) cache: Arc<TensorCache<CudaBytesPtr, Arc<CudaDevice>>>,
}

impl<E: cudarc::driver::DeviceRepr> Clone for CachableCudaSlice<E> {
@@ -142,7 +155,7 @@ impl<E: cudarc::driver::DeviceRepr> Clone for CachableCudaSlice<E> {
// SAFETY:
// 1. we know that ptr is valid for `num_bytes` because it was registered for that.
// 2. we are about to set the memory with dtod_copy
let mut slice = unsafe { dev.upgrade_device_ptr(ptr, len) };
let mut slice = unsafe { dev.upgrade_device_ptr(ptr.0, len) };
dev.dtod_copy(&self.data, &mut slice).unwrap();
slice
},
@@ -209,7 +222,7 @@ impl<E> Drop for CachableCudaSlice<E> {
let numel = data.len();
// Get access to the raw pointer without freeing it.
let ptr = data.leak();
self.cache.insert::<E>(numel, ptr);
self.cache.insert::<E>(numel, CudaBytesPtr(ptr));
}
}
}
@@ -232,18 +245,7 @@ impl Cache for Cuda {
}

fn try_empty_cache(&self) -> Result<(), Error> {
#[cfg(not(feature = "no-std"))]
let mut cache = self.cache.allocations.write().unwrap();
#[cfg(feature = "no-std")]
let mut cache = self.cache.allocations.write();
for (&key, allocations) in cache.iter_mut() {
for alloc in allocations.drain(..) {
let data = unsafe { self.dev.upgrade_device_ptr::<u8>(alloc, key.num_bytes) };
drop(data);
}
}
cache.clear();
Ok(())
self.cache.try_clear()
}
}

26 changes: 13 additions & 13 deletions dfdx-core/src/tensor/webgpu/device.rs
@@ -109,7 +109,7 @@ pub struct Webgpu {
pub(crate) dev: Arc<Device>,
pub(crate) queue: Arc<Queue>,

pub(crate) cache: Arc<TensorCache<Buffer>>,
pub(crate) cache: Arc<TensorCache<Buffer, WebGpuDevice>>,
pub(crate) cs_cache: Arc<RwLock<HashMap<TypeId, Arc<ShaderModule>>>>,
}

@@ -297,12 +297,22 @@ impl Webgpu {
// }
}

/// Unit struct to represent information needed for managing allocations on the WebGpu.
#[derive(Clone, Debug, Default)]
pub(crate) struct WebGpuDevice;

impl crate::tensor::cache::CachePtr<WebGpuDevice> for Buffer {
fn dealloc(self, _key: &crate::tensor::cache::AllocationKey, _dev: &WebGpuDevice) {
drop(self)
}
}

#[derive(Debug)]
pub struct CachableBuffer<E> {
pub(crate) dev: Arc<Device>,
pub(crate) queue: Arc<Queue>,
pub(crate) data: Buffer,
pub(crate) cache: Arc<TensorCache<Buffer>>,
pub(crate) cache: Arc<TensorCache<Buffer, WebGpuDevice>>,
pub(crate) _phantom: PhantomData<E>,
}

@@ -397,17 +407,7 @@ impl Cache for Webgpu {
}

fn try_empty_cache(&self) -> Result<(), Error> {
#[cfg(not(feature = "no-std"))]
let mut cache = self.cache.allocations.write().unwrap();
#[cfg(feature = "no-std")]
let mut cache = self.cache.allocations.write();
for (&_key, allocations) in cache.iter_mut() {
for alloc in allocations.drain(..) {
drop(alloc);
}
}
cache.clear();
Ok(())
self.cache.try_clear()
}
}

