diff --git a/Cargo.toml b/Cargo.toml index e6d9eb60..68cc915c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,9 +3,9 @@ members = ["dfdx-core", "dfdx-derives", "dfdx"] resolver = "2" [workspace.dependencies] -num-traits = { version = "0.2.15", default-features = false } -safetensors = { version = "0.3.3", default-features = false } -memmap2 = { version = "0.5", default-features = false } +num-traits = { version = "0.2.17", default-features = false } +safetensors = { version = "0.4.0", default-features = false } +memmap2 = { version = "0.9.0", default-features = false } rand = { version = "0.8.5", default-features = false, features = ["std_rng"] } rand_distr = { version = "0.4.3", default-features = false } -libm = "0.2.7" \ No newline at end of file +libm = "0.2.8" \ No newline at end of file diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml index a6d67031..f15a4d91 100644 --- a/dfdx-core/Cargo.toml +++ b/dfdx-core/Cargo.toml @@ -30,12 +30,12 @@ spin = { version = "0.9.8", default-features = false, features = ["spin_mutex", rand = { workspace = true } rand_distr = { workspace = true } zip = { version = "0.6.6", default-features = false, optional = true } -cudarc = { version = "0.9.13", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] } +cudarc = { version = "0.9.15", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] } num-traits = { workspace = true } safetensors = { workspace = true, optional = true } memmap2 = { workspace = true, optional = true } half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] } -gemm = { version = "0.15.4", default-features = false, optional = true } +gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] } rayon = { version = "1.7.0", optional = true } libm = { workspace = true } @@ -60,7 +60,7 @@ fast-alloc = ["std"] cuda = ["dep:cudarc", "dep:glob"] cudnn = ["cuda", "cudarc?/cudnn"] -f16 = ["dep:half", "cudarc?/f16"] +f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"] numpy = ["dep:zip", "std"] safetensors = ["dep:safetensors", "std", "dep:memmap2"] diff --git a/dfdx/src/lib.rs b/dfdx/src/lib.rs index 80bd81aa..cf4be4a0 100644 --- a/dfdx/src/lib.rs +++ b/dfdx/src/lib.rs @@ -329,9 +329,9 @@ pub(crate) mod tests { } #[cfg(feature = "f16")] - impl AssertClose for half::f16 { + impl AssertClose for crate::dtypes::f16 { type Elem = Self; - const DEFAULT_TOLERANCE: Self::Elem = half::f16::from_f32_const(1e-2); + const DEFAULT_TOLERANCE: Self::Elem = crate::dtypes::f16::from_f32_const(1e-2); fn get_far_pair(&self, rhs: &Self, tolerance: Self) -> Option<(Self, Self)> { if num_traits::Float::abs(self - rhs) > tolerance { Some((*self, *rhs))