-
Notifications
You must be signed in to change notification settings - Fork 950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Eigenvalues and eigenvectors #1334
Draft
kashif
wants to merge
16
commits into
ml-explore:main
Choose a base branch
from
kashif:eigenvalues
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
6735a59
initial eigvalsh
kashif f1789b3
add compute_vectors
kashif a89ee52
add compute_vectors_
kashif b5f900a
return a pair
kashif 9dd754b
Merge branch 'main' into eigenvalues
kashif 19e0148
add eigh to return only eigenvectors
kashif 3100188
fixed typo
kashif e64ca5e
Merge branch 'main' into eigenvalues
kashif 523cb3d
Merge branch 'main' into eigenvalues
kashif 97b965b
merge merge Eighvalsh and Eigh into a single primitive
kashif c0b653b
use the same primate with the flag
kashif 2383181
fix primatives
kashif dc614eb
use MULTI
kashif 5b53354
fix eval_gpu
kashif 859dd23
fix decleration
kashif dbb5c64
Merge branch 'main' into eigenvalues
kashif File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,5 @@ Linear Algebra | |
cholesky_inv | ||
qr | ||
svd | ||
eigvalsh | ||
eigh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
// Copyright © 2023-2024 Apple Inc. | ||
|
||
#include "mlx/array.h" | ||
#include "mlx/allocator.h" | ||
#include "mlx/backend/common/copy.h" | ||
#include "mlx/linalg.h" | ||
#include "mlx/primitives.h" | ||
|
||
#ifdef ACCELERATE_NEW_LAPACK | ||
#include <Accelerate/Accelerate.h> | ||
#else | ||
#include <lapack.h> | ||
#endif | ||
|
||
namespace mlx::core { | ||
|
||
namespace { | ||
|
||
// Delegate to the eigenvalue decomposition taking into account differences in | ||
// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings | ||
// to fortran). | ||
int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it needs to be in this PR, but it would be nice to add support for arbitrary matrices (not just hermitian/symmetric ones). Hopefully that could just be a flag on the |
||
int info; | ||
int lwork = -1; | ||
int liwork = -1; | ||
float work_query; | ||
int iwork_query; | ||
|
||
// Query for optimal work array sizes | ||
#ifdef LAPACK_FORTRAN_STRLEN_END | ||
ssyevd_( | ||
/* jobz = */ &jobz, | ||
/* uplo = */ &uplo, | ||
/* n = */ &N, | ||
/* a = */ matrix, | ||
/* lda = */ &N, | ||
/* w = */ w, | ||
/* work = */ &work_query, | ||
/* lwork = */ &lwork, | ||
/* iwork = */ &iwork_query, | ||
/* liwork = */ &liwork, | ||
/* info = */ &info, | ||
/* jobz_len = */ static_cast<size_t>(1), | ||
/* uplo_len = */ static_cast<size_t>(1)); | ||
#else | ||
ssyevd_( | ||
/* jobz = */ &jobz, | ||
/* uplo = */ &uplo, | ||
/* n = */ &N, | ||
/* a = */ matrix, | ||
/* lda = */ &N, | ||
/* w = */ w, | ||
/* work = */ &work_query, | ||
/* lwork = */ &lwork, | ||
/* iwork = */ &iwork_query, | ||
/* liwork = */ &liwork, | ||
/* info = */ &info); | ||
#endif | ||
|
||
lwork = static_cast<int>(work_query); | ||
liwork = iwork_query; | ||
|
||
std::vector<float> work(lwork); | ||
std::vector<int> iwork(liwork); | ||
|
||
// Compute eigenvalues (and optionally eigenvectors) | ||
#ifdef LAPACK_FORTRAN_STRLEN_END | ||
ssyevd_( | ||
/* jobz = */ &jobz, | ||
/* uplo = */ &uplo, | ||
/* n = */ &N, | ||
/* a = */ matrix, | ||
/* lda = */ &N, | ||
/* w = */ w, | ||
/* work = */ work.data(), | ||
/* lwork = */ &lwork, | ||
/* iwork = */ iwork.data(), | ||
/* liwork = */ &liwork, | ||
/* info = */ &info, | ||
/* jobz_len = */ static_cast<size_t>(1), | ||
/* uplo_len = */ static_cast<size_t>(1)); | ||
#else | ||
ssyevd_( | ||
/* jobz = */ &jobz, | ||
/* uplo = */ &uplo, | ||
/* n = */ &N, | ||
/* a = */ matrix, | ||
/* lda = */ &N, | ||
/* w = */ w, | ||
/* work = */ work.data(), | ||
/* lwork = */ &lwork, | ||
/* iwork = */ iwork.data(), | ||
/* liwork = */ &liwork, | ||
/* info = */ &info); | ||
#endif | ||
|
||
return info; | ||
} | ||
|
||
} // namespace | ||
|
||
void eigh_impl( | ||
const array& a, | ||
array& values, | ||
array& vectors, | ||
bool upper, | ||
bool compute_eigenvectors) { | ||
char jobz = compute_eigenvectors ? 'V' : 'N'; | ||
char uplo = upper ? 'U' : 'L'; | ||
|
||
array buffer = copy(a); | ||
|
||
const int N = static_cast<int>(a.shape(-1)); | ||
const int num_matrices = static_cast<int>(a.size() / (N * N)); | ||
|
||
std::vector<int> values_shape = {num_matrices, N}; | ||
values = array(allocator::malloc(num_matrices * N * size_of(a.dtype())), values_shape, a.dtype()); | ||
|
||
float* matrix = buffer.data<float>(); | ||
float* w = values.data<float>(); | ||
|
||
if (compute_eigenvectors) { | ||
std::vector<int> vectors_shape = a.shape(); | ||
vectors = array(allocator::malloc(a.size() * size_of(a.dtype())), vectors_shape, a.dtype()); | ||
} | ||
|
||
float* vecs = compute_eigenvectors ? vectors.data<float>() : nullptr; | ||
|
||
for (int i = 0; i < num_matrices; i++) { | ||
int info = ssyevd_wrapper(jobz, uplo, matrix, w, N); | ||
|
||
if (info != 0) { | ||
std::stringstream msg; | ||
msg << "[eigh] Eigenvalue decomposition failed with error code " << info; | ||
throw std::runtime_error(msg.str()); | ||
} | ||
|
||
if (compute_eigenvectors) { | ||
// Copy eigenvectors to the output array | ||
std::copy(matrix, matrix + N * N, vecs); | ||
vecs += N * N; | ||
} | ||
|
||
matrix += N * N; | ||
w += N; | ||
} | ||
} | ||
|
||
void EighPrimitive::eval( | ||
const std::vector<array>& inputs, | ||
std::vector<array>& outputs) { | ||
// Validate the number of inputs | ||
if (inputs.size() != 1) { | ||
throw std::invalid_argument("[EighPrimitive::eval] Expected exactly one input array."); | ||
} | ||
|
||
const array& input = inputs[0]; | ||
|
||
// Ensure the input array is evaluated before accessing its data | ||
const_cast<array&>(input).eval(); | ||
|
||
// Validate the data type | ||
Dtype input_dtype = input.dtype(); // Changed from 'dtype_t' to 'Dtype' | ||
|
||
// Validate the number of dimensions (expecting at least 2D) | ||
if (input.ndim() < 2) { | ||
throw std::invalid_argument("[EighPrimitive::eval] Input array must be at least 2-dimensional."); | ||
} | ||
|
||
array values{}; | ||
array vectors{}; | ||
eigh_impl(input, values, vectors, upper_, compute_eigenvectors_); | ||
|
||
// Ensure the output arrays are evaluated | ||
values.eval(); | ||
if (compute_eigenvectors_) { | ||
vectors.eval(); | ||
outputs = {values, vectors}; | ||
} else { | ||
outputs = {values}; | ||
} | ||
} | ||
|
||
} // namespace mlx::core |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please rename this to
Eigh
for consistency with other primitive names.