Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eigenvalues and eigenvectors #1334

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions docs/src/python/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ Linear Algebra
cholesky_inv
qr
svd
eigvalsh
eigh
1 change: 1 addition & 0 deletions mlx/backend/accelerate/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(EighPrimitive)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this to Eigh for consistency with other primitive names.


void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigvalsh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/common/default_primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(EighPrimitive)

namespace {

Expand Down
184 changes: 184 additions & 0 deletions mlx/backend/common/eigvalsh.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/array.h"
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

namespace {

// Delegate to the eigenvalue decomposition taking into account differences in
// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings
// to fortran).
// NOTE(review): it would be nice to eventually support arbitrary (not just
// hermitian/symmetric) matrices too — likely a flag on the primitive that
// selects a different LAPACK routine. Out of scope for this PR.
//
// Returns the LAPACK `info` code: 0 on success, < 0 for an invalid argument,
// > 0 if the algorithm failed to converge.
int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) {
  int info;
  int lwork = -1;
  int liwork = -1;

  // Single call site for ssyevd_ so the FORTRAN hidden string-length quirk
  // is handled in exactly one place instead of duplicating the whole
  // argument list for the query and the compute calls.
  auto call_ssyevd = [&](float* work, int* iwork) {
#ifdef LAPACK_FORTRAN_STRLEN_END
    ssyevd_(
        /* jobz = */ &jobz,
        /* uplo = */ &uplo,
        /* n = */ &N,
        /* a = */ matrix,
        /* lda = */ &N,
        /* w = */ w,
        /* work = */ work,
        /* lwork = */ &lwork,
        /* iwork = */ iwork,
        /* liwork = */ &liwork,
        /* info = */ &info,
        /* jobz_len = */ static_cast<size_t>(1),
        /* uplo_len = */ static_cast<size_t>(1));
#else
    ssyevd_(
        /* jobz = */ &jobz,
        /* uplo = */ &uplo,
        /* n = */ &N,
        /* a = */ matrix,
        /* lda = */ &N,
        /* w = */ w,
        /* work = */ work,
        /* lwork = */ &lwork,
        /* iwork = */ iwork,
        /* liwork = */ &liwork,
        /* info = */ &info);
#endif
  };

  // Workspace query: lwork == liwork == -1 asks LAPACK to report the optimal
  // workspace sizes in work[0] / iwork[0].
  float work_query;
  int iwork_query;
  call_ssyevd(&work_query, &iwork_query);
  if (info != 0) {
    // Don't trust the (unset) workspace sizes if the query itself failed.
    return info;
  }

  lwork = static_cast<int>(work_query);
  liwork = iwork_query;

  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);

  // Compute eigenvalues (and eigenvectors when jobz == 'V'). Results land in
  // `w`; with jobz == 'V' the eigenvectors overwrite `matrix`.
  call_ssyevd(work.data(), iwork.data());
  return info;
}

} // namespace

void eigh_impl(
const array& a,
array& values,
array& vectors,
bool upper,
bool compute_eigenvectors) {
char jobz = compute_eigenvectors ? 'V' : 'N';
char uplo = upper ? 'U' : 'L';

array buffer = copy(a);

const int N = static_cast<int>(a.shape(-1));
const int num_matrices = static_cast<int>(a.size() / (N * N));

std::vector<int> values_shape = {num_matrices, N};
values = array(allocator::malloc(num_matrices * N * size_of(a.dtype())), values_shape, a.dtype());

float* matrix = buffer.data<float>();
float* w = values.data<float>();

if (compute_eigenvectors) {
std::vector<int> vectors_shape = a.shape();
vectors = array(allocator::malloc(a.size() * size_of(a.dtype())), vectors_shape, a.dtype());
}

float* vecs = compute_eigenvectors ? vectors.data<float>() : nullptr;

for (int i = 0; i < num_matrices; i++) {
int info = ssyevd_wrapper(jobz, uplo, matrix, w, N);

if (info != 0) {
std::stringstream msg;
msg << "[eigh] Eigenvalue decomposition failed with error code " << info;
throw std::runtime_error(msg.str());
}

if (compute_eigenvectors) {
// Copy eigenvectors to the output array
std::copy(matrix, matrix + N * N, vecs);
vecs += N * N;
}

matrix += N * N;
w += N;
}
}

void EighPrimitive::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Validate the number of inputs
if (inputs.size() != 1) {
throw std::invalid_argument("[EighPrimitive::eval] Expected exactly one input array.");
}

const array& input = inputs[0];

// Ensure the input array is evaluated before accessing its data
const_cast<array&>(input).eval();

// Validate the data type
Dtype input_dtype = input.dtype(); // Changed from 'dtype_t' to 'Dtype'

// Validate the number of dimensions (expecting at least 2D)
if (input.ndim() < 2) {
throw std::invalid_argument("[EighPrimitive::eval] Input array must be at least 2-dimensional.");
}

array values{};
array vectors{};
eigh_impl(input, values, vectors, upper_, compute_eigenvectors_);

// Ensure the output arrays are evaluated
values.eval();
if (compute_eigenvectors_) {
vectors.eval();
outputs = {values, vectors};
} else {
outputs = {values};
}
}

} // namespace mlx::core
4 changes: 4 additions & 0 deletions mlx/backend/metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}

// Metal backend has no eigendecomposition kernel yet.
void EighPrimitive::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Tag the message with the actual class name so the error is attributable
  // (was mislabeled "[Eigvalsh::eval_gpu]").
  throw std::runtime_error(
      "[EighPrimitive::eval_gpu] Metal eigenvalue decomposition NYI.");
}

void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU_MULTI(EighPrimitive)
NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(EighPrimitive)
NO_GPU(View)

namespace fast {
Expand Down
66 changes: 66 additions & 0 deletions mlx/linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,70 @@ array cholesky_inv(
}
}

// Compute the eigenvalues of a symmetric (hermitian) float32 matrix.
//
// `a` must have shape (..., N, N); the result has shape (..., N). `upper`
// selects which triangle of each matrix is read.
array eigvalsh(
    const array& a,
    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  if (a.dtype() != float32) {
    std::ostringstream msg;
    msg << "[linalg::eigvalsh] Arrays must be type float32. Received array "
        << "with type " << a.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }

  if (a.ndim() < 2) {
    std::ostringstream msg;
    msg << "[linalg::eigvalsh] Arrays must have >= 2 dimensions. Received array "
           "with "
        << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (a.shape(-1) != a.shape(-2)) {
    throw std::invalid_argument(
        "[linalg::eigvalsh] Eigenvalues are only defined for square matrices.");
  }

  // Dropping the last axis is sufficient: the square check above guarantees
  // shape(-2) == shape(-1), so the old `out_shape.back() = a.shape(-1)`
  // rewrite was a no-op.
  std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);

  return array(
      out_shape,
      a.dtype(),
      std::make_shared<EighPrimitive>(to_stream(s), upper, false),
      {astype(a, a.dtype(), s)});
}

std::pair<array, array> eigh(
const array& a,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::eigh] Arrays must be type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}

if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::eigh] Arrays must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(
"[linalg::eigh] Eigenvectors are only defined for square matrices.");
}

auto out = array::make_arrays(
{std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
std::make_shared<EighPrimitive>(to_stream(s), upper, true),
{astype(a, a.dtype(), s)});
return std::make_pair(out[0], out[1]);
}

} // namespace mlx::core::linalg
4 changes: 4 additions & 0 deletions mlx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {});

std::pair<array, array> eigh(const array& a, bool upper = false, StreamOrDevice s = {});

} // namespace mlx::core::linalg
22 changes: 22 additions & 0 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,28 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
return {{linalg::cholesky(a, upper_, stream())}, {ax}};
}

// Vectorized-map rule: apply eigh/eigvalsh along a mapped batch axis by
// moving that axis to the front and relying on the batched linalg ops to
// handle the leading dimension.
std::pair<std::vector<array>, std::vector<int>> EighPrimitive::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 1);
  assert(axes.size() == 1);

  // Output batch axis: 0 when the input is mapped, otherwise -1
  // (presumably the project's "not mapped" sentinel — confirm against the
  // other vmap implementations in this file).
  auto ax = axes[0] >= 0 ? 0 : -1;
  // Bring the mapped axis to the front; axis 0 needs no move and a negative
  // axis means the input is not mapped at all.
  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];

  std::vector<array> outputs;
  if (compute_eigenvectors_) {
    auto [values, vectors] = linalg::eigh(a, upper_, stream());
    outputs = {values, vectors};
  } else {
    outputs = {linalg::eigvalsh(a, upper_, stream())};
  }

  // Every output carries the batch on the same axis.
  std::vector<int> out_axes(outputs.size(), ax);

  return {outputs, out_axes};
}

std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
Expand Down
35 changes: 35 additions & 0 deletions mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -2166,4 +2166,39 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};

// Primitive computing the eigenvalues (and optionally eigenvectors) of a
// symmetric/hermitian matrix. Backs linalg::eigvalsh (values only) and
// linalg::eigh (values and vectors).
class EighPrimitive : public Primitive {
 public:
  // `upper`: read the upper (true) or lower (false) triangle of the input.
  // `compute_eigenvectors`: when true the primitive has two outputs
  // (values, vectors); otherwise only the values.
  explicit EighPrimitive(Stream stream, bool upper, bool compute_eigenvectors)
      : Primitive(stream), upper_(upper), compute_eigenvectors_(compute_eigenvectors) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) override;

  DEFINE_VMAP()
  DEFINE_PRINT(EighPrimitive)

  // Eigenvalues of an (..., N, N) input have shape (..., N); eigenvectors
  // keep the input shape.
  std::vector<std::vector<int>> output_shapes(
      const std::vector<array>& inputs) override {
    auto shape = inputs[0].shape();
    shape.pop_back(); // Remove last dimension for eigenvalues
    if (compute_eigenvectors_) {
      return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
    } else {
      return {shape}; // Only eigenvalues
    }
  }

  // Two instances are interchangeable iff they agree on both flags.
  bool is_equivalent(const Primitive& other) const override {
    if (auto* p = dynamic_cast<const EighPrimitive*>(&other)) {
      return upper_ == p->upper_ && compute_eigenvectors_ == p->compute_eigenvectors_;
    }
    return false;
  }

 private:
  // Shared CPU implementation used by eval_cpu.
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
  bool upper_;
  bool compute_eigenvectors_;
};

} // namespace mlx::core
Loading