[Breaking] Adds AMP<F> dtype (#811)
* Adds AMP<F> dtype

* impl sum for amp cpu

* impl amp kernels for cpu optimizers

* Moving NotMixedPrecision to dtypes

* Adding AMP implementations for cuda kernels

* Fixing cuda errors & warnings

* Adds Gemm impl for AMP<f16> for CudaBlas

* Adding chunk_sum for amp f16

* bump cudarc version

* Update src/dtypes/amp.rs

Co-authored-by: nkoppel <nathankoppel0@gmail.com>

* More generic AMP

* Fixing unused imports

---------

Co-authored-by: nkoppel <nathankoppel0@gmail.com>
coreylowman and nkoppel committed Jul 26, 2023
1 parent 40996df commit 0b49672
Showing 67 changed files with 1,825 additions and 191 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -30,7 +30,7 @@ spin = { version = "0.9.8", default-features = false, features = ["spin_mutex",
rand = { version = "0.8.5", default-features = false, features = ["std_rng"] }
rand_distr = { version = "0.4.3", default-features = false }
zip = { version = "0.6.6", default-features = false, optional = true }
-cudarc = { version = "0.9.11", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
+cudarc = { version = "0.9.13", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] }
num-traits = { version = "0.2.15", default-features = false }
safetensors = { version = "0.3", default-features = false, optional = true }
memmap2 = { version = "0.5", default-features = false, optional = true }
@@ -65,6 +65,7 @@ numpy = ["dep:zip", "std"]
safetensors = ["dep:safetensors", "std", "dep:memmap2"]

test-f16 = ["f16"]
+test-amp-f16 = ["f16"]
test-f64 = []
test-integrations = []
ci-check = ["cudarc?/ci-check"]