From 6ae0b0409882737f3d1c3b960dd4d082c53b45aa Mon Sep 17 00:00:00 2001 From: Jack Baker Date: Thu, 7 Sep 2023 15:59:29 -0400 Subject: [PATCH 1/6] Changed file structure to accomodate custom eigh implementation --- grad_dft/external/__init__.py | 1 - grad_dft/external/eigh_impl.py | 116 --------------------------------- grad_dft/molecule.py | 2 +- grad_dft/utils/eigenproblem.py | 18 +++++ 4 files changed, 19 insertions(+), 118 deletions(-) delete mode 100644 grad_dft/external/eigh_impl.py create mode 100644 grad_dft/utils/eigenproblem.py diff --git a/grad_dft/external/__init__.py b/grad_dft/external/__init__.py index 9e05165..963c075 100644 --- a/grad_dft/external/__init__.py +++ b/grad_dft/external/__init__.py @@ -8,4 +8,3 @@ from grad_dft.external.density_functional_approximation_dm21.density_functional_approximation_dm21.neural_numint import ( _SystemState, ) -from grad_dft.external.eigh_impl import eigh2d diff --git a/grad_dft/external/eigh_impl.py b/grad_dft/external/eigh_impl.py deleted file mode 100644 index 9b29cea..0000000 --- a/grad_dft/external/eigh_impl.py +++ /dev/null @@ -1,116 +0,0 @@ -# Imported from https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e - -"""Versions based on 4.60 and 4.63 of https://arxiv.org/pdf/1701.00392.pdf.""" -import jax -import jax.numpy as jnp -import numpy as np - - -def _T(x): - return jnp.swapaxes(x, -1, -2) - - -def _H(x): - return jnp.conj(_T(x)) - - -def symmetrize(x): - return (x + _H(x)) / 2 - - -def standardize_angle(w, b): - if jnp.isrealobj(w): - return w * jnp.sign(w[0, :]) - else: - # scipy does this: makes imag(b[0] @ w) = 1 - assert not jnp.isrealobj(b) - bw = b[0] @ w - factor = bw / jnp.abs(bw) - w = w / factor[None, :] - sign = jnp.sign(w.real[0]) - w = w * sign - return w - - -@jax.custom_jvp # jax.scipy.linalg.eigh doesn't support general problem i.e. b not None -def eigh2d(a, b): - """ - Compute the solution to the symmetrized generalized eigenvalue problem. - a_s @ w = b_s @ w @ np.diag(v) - where a_s = (a + a.H) / 2, b_s = (b + b.H) / 2 are the symmetrized versions of the - inputs and H is the Hermitian (conjugate transpose) operator. - For self-adjoint inputs the solution should be consistent with `scipy.linalg.eigh` - i.e. - ```python - v, w = eigh(a, b) - v_sp, w_sp = scipy.linalg.eigh(a, b) - np.testing.assert_allclose(v, v_sp) - np.testing.assert_allclose(w, standardize_angle(w_sp)) - ``` - Note this currently uses `jax.linalg.eig(jax.linalg.solve(b, a))`, which will be - slow because there is no GPU implementation of `eig` and it's just a generally - inefficient way of doing it. Future implementations should wrap cuda primitives. - This implementation is provided primarily as a means to test `eigh_jvp_rule`. - Args: - a: [n, n] float self-adjoint matrix (i.e. conj(transpose(a)) == a) - b: [n, n] float self-adjoint matrix (i.e. conj(transpose(b)) == b) - Returns: - v: eigenvalues of the generalized problem in ascending order. - w: eigenvectors of the generalized problem, normalized such that - w.H @ b @ w = I. - """ - a = symmetrize(a) - b = symmetrize(b) - b_inv_a = jax.scipy.linalg.cho_solve(jax.scipy.linalg.cho_factor(b), a) - v, w = jax.jit(jax.numpy.linalg.eig, backend="cpu")(b_inv_a) - v = v.real - # with loops.Scope() as s: - # for _ in s.cond_range(jnp.isrealobj) - if jnp.isrealobj(a) and jnp.isrealobj(b): - w = w.real - # reorder as ascending in w - order = jnp.argsort(v) - v = v.take(order, axis=0) - w = w.take(order, axis=1) - # renormalize so v.H @ b @ H == 1 - norm2 = jax.vmap(lambda wi: (wi.conj() @ b @ wi).real, in_axes=1)(w) - norm = jnp.sqrt(norm2) - w = w / norm - w = standardize_angle(w, b) - return v, w - - -@eigh2d.defjvp -def eigh_jvp_rule(primals, tangents): - """ - Derivation based on Boedekker et al. - https://arxiv.org/pdf/1701.00392.pdf - Note diagonal entries of Winv dW/dt != 0 as they claim. - """ - a, b = primals - da, db = tangents - if not all(jnp.isrealobj(x) for x in (a, b, da, db)): - raise NotImplementedError("jvp only implemented for real inputs.") - da = symmetrize(da) - db = symmetrize(db) - - v, w = eigh2d(a, b) - - # compute only the diagonal entries - dv = jax.vmap( - lambda vi, wi: -wi.conj() @ db @ wi * vi + wi.conj() @ da @ wi, - in_axes=(0, 1), - )(v, w) - - dv = dv.real - - E = v[jnp.newaxis, :] - v[:, jnp.newaxis] - - # diagonal entries: compute as column then put into diagonals - diags = jnp.diag(-0.5 * jax.vmap(lambda wi: wi.conj() @ db @ wi, in_axes=1)(w)) - # off-diagonals: there will be NANs on the diagonal, but these aren't used - off_diags = jnp.reciprocal(E) * (_H(w) @ (da @ w - db @ w * v[jnp.newaxis, :])) - - dw = w @ jnp.where(jnp.eye(a.shape[0], dtype=np.bool), diags, off_diags) - - return (v, w), (dv, dw) diff --git a/grad_dft/molecule.py b/grad_dft/molecule.py index dfafa09..8261aec 100644 --- a/grad_dft/molecule.py +++ b/grad_dft/molecule.py @@ -16,7 +16,7 @@ from dataclasses import fields from grad_dft.utils import Array, Scalar, PyTree, vmap_chunked from functools import partial -from grad_dft.external.eigh_impl import eigh2d +from grad_dft.utils.eigenproblem import safe_general_eigh from jax import numpy as jnp from jax import scipy as jsp diff --git a/grad_dft/utils/eigenproblem.py b/grad_dft/utils/eigenproblem.py new file mode 100644 index 0000000..8230397 --- /dev/null +++ b/grad_dft/utils/eigenproblem.py @@ -0,0 +1,18 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +from jax import custom_vjp + From c25a8bf9ecb169af3ca09297338b905fa74c57b3 Mon Sep 17 00:00:00 2001 From: Jack Baker Date: Fri, 8 Sep 2023 16:08:18 -0400 Subject: [PATCH 2/6] eigensolver operation controlled by a couple of constants for DEGEN_TOL and BROADEDING. Done this way as messing with these can also mess with stability --- grad_dft/evaluate.py | 12 +++---- grad_dft/utils/__init__.py | 1 + grad_dft/utils/eigenproblem.py | 63 ++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 6 deletions(-) diff --git a/grad_dft/evaluate.py b/grad_dft/evaluate.py index 9b06f65..afa7c91 100644 --- a/grad_dft/evaluate.py +++ b/grad_dft/evaluate.py @@ -32,9 +32,9 @@ from grad_dft.utils import PyTree, Array, Scalar, Optimizer from grad_dft.functional import Functional -from grad_dft.molecule import Molecule, abs_clip, make_rdm1, orbital_grad, general_eigh +from grad_dft.molecule import Molecule, abs_clip, make_rdm1, orbital_grad from grad_dft.train import molecule_predictor -from grad_dft.utils import PyTree, Array, Scalar +from grad_dft.utils import PyTree, Array, Scalar, safe_fock_solver from grad_dft.interface.pyscf import ( generate_chi_tensor, mol_from_Molecule, @@ -117,7 +117,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca else: # Diagonalize Fock matrix overlap = abs_clip(molecule.s1e, 1e-20) - mo_energy, mo_coeff = general_eigh(fock, overlap) + mo_energy, mo_coeff = safe_fock_solver(fock, overlap) molecule = molecule.replace(mo_coeff=mo_coeff) molecule = molecule.replace(mo_energy=mo_energy) @@ -275,7 +275,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca ) # Diagonalize Fock matrix - mo_energy, mo_coeff = general_eigh(fock, molecule.s1e) + mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e) molecule = molecule.replace(mo_coeff=mo_coeff) molecule = molecule.replace(mo_energy=mo_energy) @@ -372,7 +372,7 @@ def nelec_cost_fn(m, mo_es, sigma, _nelectron): if abs(predicted_e - old_e) * Hartree2kcalmol < e_conv and norm_gorb < g_conv: # We perform an extra diagonalization to remove the level shift # Solve eigenvalue problem - mo_energy, mo_coeff = general_eigh(fock, molecule.s1e) + mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e) molecule = molecule.replace(mo_coeff=mo_coeff) molecule = molecule.replace(mo_energy=mo_energy) @@ -662,7 +662,7 @@ def loop_body(cycle, state): fock, diis_data = diis.run(new_data, diis_data, cycle) # Diagonalize Fock matrix - mo_energy, mo_coeff = general_eigh(fock, molecule.s1e) + mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e) molecule = molecule.replace(mo_coeff=mo_coeff) molecule = molecule.replace(mo_energy=mo_energy) diff --git a/grad_dft/utils/__init__.py b/grad_dft/utils/__init__.py index 6bf7f82..32c02b3 100644 --- a/grad_dft/utils/__init__.py +++ b/grad_dft/utils/__init__.py @@ -28,3 +28,4 @@ from .tree import tree_size, tree_isfinite, tree_randn_like, tree_func, tree_shape from .utils import to_device_arrays, Utils from .chunk import vmap_chunked +from .eigenproblem import safe_fock_solver diff --git a/grad_dft/utils/eigenproblem.py b/grad_dft/utils/eigenproblem.py index 8230397..cdfc80a 100644 --- a/grad_dft/utils/eigenproblem.py +++ b/grad_dft/utils/eigenproblem.py @@ -15,4 +15,67 @@ import jax import jax.numpy as jnp from jax import custom_vjp +from .types import Array, Scalar +# Probably don't alter these unless you know what you're doing +DEGEN_TOL = 1e-6 +BROADENING = 1e-7 + +@custom_vjp +def safe_eigh(A: Array): + evecs, evals = jnp.linalg.eigh(A) + return evecs.real, evals.real + +def safe_eigh_fwd(A: Array): + evecs, evals = safe_eigh(A) + return (evecs, evals), ((evecs, evals), A) + +def safe_eigh_rev(res: Array, g: Array): + """Apply lorentzian broadening to degenerate eigenvalues/eignvectors + """ + (evals, evecs), A = res + grad_evals, grad_evecs = g + grad_evals_diag = jnp.diag(grad_evals) + evecs_trans = evecs.T + + # Generate eigenvalue difference matrix + eval_diff = evals.reshape((1, -1)) - evals.reshape((-1, 1)) + # Find elements where degen_tol condition was or wasn't was met + mask_degen = (jnp.abs(eval_diff) < DEGEN_TOL).astype(jnp.int32) + mask_non_degen = (jnp.abs(eval_diff) >= DEGEN_TOL).astype(jnp.int32) + + # Regular gap for non_degen terms => 1/(e_j - e_i) + # Will get +infs turning to large numbers here if degenerarcies are present. + # This doesn't matters as they multiply by 0 in the forthcoming mask when calculating + # the F-matrix + regular_gap = jnp.nan_to_num(jnp.divide(1, eval_diff)) + + # Lorentzian broadened gap for degen terms => (e_j - e_i)/((e_j - e_i)^2 + eps) + broadened_gap = eval_diff / (eval_diff*eval_diff + BROADENING) + + # Calculate full F matrix. large numbers generated by NaNs from regular_gap are deleted here + F = 0.5*(jnp.multiply(mask_non_degen, regular_gap) + jnp.multiply(mask_degen, broadened_gap)) + + # Set diagonals to 0 + F = F.at[jnp.diag_indices_from(F)].set(0) + + # Calculate the gradient + grad = jnp.linalg.inv(evecs_trans) @ (0.5*grad_evals_diag + jnp.multiply(F, evecs_trans @ grad_evecs)) @ evecs_trans + # Symmetrize + grad_sym = grad + grad.T + return grad_sym, + +safe_eigh.defvjp(safe_eigh_fwd, safe_eigh_rev) + +def safe_general_eigh(A: Array, B: Array): + L = jnp.linalg.cholesky(B) + L_inv = jnp.linalg.inv(L) + C = L_inv @ A @ L_inv.T + eigenvalues, eigenvectors_transformed = safe_eigh(C) + eigenvectors_original = L_inv.T @ eigenvectors_transformed + return eigenvalues, eigenvectors_original + +def safe_fock_solver(fock: Array, overlap: Array): + mo_energies_up, mo_coeffs_up = safe_general_eigh(fock[0], overlap) + mo_energies_dn, mo_coeffs_dn = safe_general_eigh(fock[1], overlap) + return jnp.stack((mo_energies_up, mo_energies_dn), axis=0), jnp.stack((mo_coeffs_up, mo_coeffs_dn), axis=0) \ No newline at end of file From f1cd50c6a7a77976d0b27649415400a82aeae80c Mon Sep 17 00:00:00 2001 From: Jack Baker Date: Fri, 8 Sep 2023 16:47:28 -0400 Subject: [PATCH 3/6] docstrings and black formatting --- grad_dft/utils/eigenproblem.py | 103 ++++++++++++++++++++++++++------- 1 file changed, 83 insertions(+), 20 deletions(-) diff --git a/grad_dft/utils/eigenproblem.py b/grad_dft/utils/eigenproblem.py index cdfc80a..f383bf5 100644 --- a/grad_dft/utils/eigenproblem.py +++ b/grad_dft/utils/eigenproblem.py @@ -21,53 +21,104 @@ DEGEN_TOL = 1e-6 BROADENING = 1e-7 + @custom_vjp -def safe_eigh(A: Array): +def safe_eigh(A: Array) -> tuple[Array, Array]: + r"""Get the eigenvalues and eigenvectors for an input real symmetric matrix. + A safe reverse mode gradient is implemented in safe_eigh_rev below. + + Args: + A (Array): a 2D Jax array representing a real symmetric matrix. + + Returns: + tuple[Array, Array]: the eigenvalues and eigenvectors of the input real symmetric matrix. + """ evecs, evals = jnp.linalg.eigh(A) - return evecs.real, evals.real + return evecs, evals + + +def safe_eigh_fwd(A: Array) -> tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]: + r"""Forward mode operation of safe_eigh. Saves evecs and evals for the reverse pass. -def safe_eigh_fwd(A: Array): + Args: + A (Array): a 2D Jax array representing a real symmetric matrix. + + Returns: + tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]: eigenvectors, eigenvalues and the input real symmetric matrix A. + """ evecs, evals = safe_eigh(A) return (evecs, evals), ((evecs, evals), A) -def safe_eigh_rev(res: Array, g: Array): - """Apply lorentzian broadening to degenerate eigenvalues/eignvectors + +def safe_eigh_rev(res: tuple[tuple[Array, Array], Array], g: Array) -> tuple[Array]: + r"""Use the Lorentzian broading approach suggested in https://doi.org/10.1038/s42005-021-00568-6 + to calculate stable backward mode gradients for degenerate eigenvectors. We only apply this + technique if eigenvalues are detected to be degenerate according to the constant DEGEN_TOL + in this module. When degeneracies are detected, the are broadened according to the constant + BROADENING also defined in this module. + + Args: + res (tuple[tuple[Array, Array]): eigenvectors, eigenvales and the input real symmetric matrix A saved from the forward pass + g (Array): the gradients d[eigenvalues]/dA and d[eigenvectors]/dA + + Returns: + tuple[Array]: the matrix of reverse mode gradients. + """ (evals, evecs), A = res grad_evals, grad_evecs = g grad_evals_diag = jnp.diag(grad_evals) evecs_trans = evecs.T - + # Generate eigenvalue difference matrix eval_diff = evals.reshape((1, -1)) - evals.reshape((-1, 1)) # Find elements where degen_tol condition was or wasn't was met mask_degen = (jnp.abs(eval_diff) < DEGEN_TOL).astype(jnp.int32) - mask_non_degen = (jnp.abs(eval_diff) >= DEGEN_TOL).astype(jnp.int32) - + mask_non_degen = (jnp.abs(eval_diff) >= DEGEN_TOL).astype(jnp.int32) + # Regular gap for non_degen terms => 1/(e_j - e_i) - # Will get +infs turning to large numbers here if degenerarcies are present. - # This doesn't matters as they multiply by 0 in the forthcoming mask when calculating + # Will get +infs turning to large numbers here if degeneracies are present. + # This doesn't matter as they multiply by 0 in the forthcoming mask when calculating # the F-matrix regular_gap = jnp.nan_to_num(jnp.divide(1, eval_diff)) # Lorentzian broadened gap for degen terms => (e_j - e_i)/((e_j - e_i)^2 + eps) - broadened_gap = eval_diff / (eval_diff*eval_diff + BROADENING) - - # Calculate full F matrix. large numbers generated by NaNs from regular_gap are deleted here - F = 0.5*(jnp.multiply(mask_non_degen, regular_gap) + jnp.multiply(mask_degen, broadened_gap)) - + broadened_gap = eval_diff / (eval_diff * eval_diff + BROADENING) + + # Calculate full F matrix. large numbers generated by NaNs from regular_gap are deleted here + F = 0.5 * (jnp.multiply(mask_non_degen, regular_gap) + jnp.multiply(mask_degen, broadened_gap)) + # Set diagonals to 0 F = F.at[jnp.diag_indices_from(F)].set(0) # Calculate the gradient - grad = jnp.linalg.inv(evecs_trans) @ (0.5*grad_evals_diag + jnp.multiply(F, evecs_trans @ grad_evecs)) @ evecs_trans + grad = ( + jnp.linalg.inv(evecs_trans) + @ (0.5 * grad_evals_diag + jnp.multiply(F, evecs_trans @ grad_evecs)) + @ evecs_trans + ) # Symmetrize grad_sym = grad + grad.T - return grad_sym, + return (grad_sym,) + safe_eigh.defvjp(safe_eigh_fwd, safe_eigh_rev) -def safe_general_eigh(A: Array, B: Array): + +def safe_general_eigh(A: Array, B: Array) -> tuple[Array, Array]: + r"""Solve the general eigenproblem for the eigenvalues and eigenvectors. I.e, + . math:: + AC = ECB + for matrix of eigenvectors C and diagonal matrix of eigenvalues E. This function requires all input + matrices to real and symmetric and the matrix B to be invertible. + + Args: + A (Array): a real symmetric matrix + B (Array): another real symmetric matrix + + Returns: + tuple[Array, Array]: the eigenvalues and matrix of eigenvectors + """ L = jnp.linalg.cholesky(B) L_inv = jnp.linalg.inv(L) C = L_inv @ A @ L_inv.T @@ -75,7 +126,19 @@ def safe_general_eigh(A: Array, B: Array): eigenvectors_original = L_inv.T @ eigenvectors_transformed return eigenvalues, eigenvectors_original -def safe_fock_solver(fock: Array, overlap: Array): + +def safe_fock_solver(fock: tuple[Array, Array], overlap: Array) -> tuple[Array, Array]: + """Get the eigenenergies and molecular orbital coefficients for the + up and down fock spin matrices. + Args: + fock (tuple[Array, Array]): the up and down fock spin matrices + overlap (Array): the overlap matrix + + Returns: + tuple[Array, Array]: the eigenenergies and matrix of molecular orbital coefficients. + """ mo_energies_up, mo_coeffs_up = safe_general_eigh(fock[0], overlap) mo_energies_dn, mo_coeffs_dn = safe_general_eigh(fock[1], overlap) - return jnp.stack((mo_energies_up, mo_energies_dn), axis=0), jnp.stack((mo_coeffs_up, mo_coeffs_dn), axis=0) \ No newline at end of file + return jnp.stack((mo_energies_up, mo_energies_dn), axis=0), jnp.stack( + (mo_coeffs_up, mo_coeffs_dn), axis=0 + ) From 0577f644aed48878e6f804c297aae0fb182c85fb Mon Sep 17 00:00:00 2001 From: Jack Baker Date: Fri, 8 Sep 2023 18:06:44 -0400 Subject: [PATCH 4/6] Added tests to check safe_eigh implementation --- tests/unit/test_eigenproblem.py | 122 ++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 tests/unit/test_eigenproblem.py diff --git a/tests/unit/test_eigenproblem.py b/tests/unit/test_eigenproblem.py new file mode 100644 index 0000000..ab884df --- /dev/null +++ b/tests/unit/test_eigenproblem.py @@ -0,0 +1,122 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The goal of this module is to test that the code in ~/utils/eigenproblem.py produces: + +(1) The same gradients as jnp.linalg.eigh for input real symmetric matrices with non-degenerate eigenvalues. + +(2) Gradients free of NaN's when the problem is degenerate. + +Subsequently, we only aim to test the implementation safe_eigh +""" + +from jax import config +from jax.random import PRNGKey, normal, randint +import jax.numpy as jnp +from jax import jacrev + +import numpy as np + +from grad_dft.utils.eigenproblem import safe_eigh +from grad_dft.utils import Array, Scalar + +import pytest + +config.update("jax_enable_x64", True) + +ABS_THRESH = 1e-10 +SEEDS = [1984, 1993, 1945, 2001, 10, 29, 101, 1992] +RANDOM_KEYS = [PRNGKey(seed) for seed in SEEDS] +MATRIX_SIZES = jnp.arange(2, 10) +GRAD_REV_FN_JNP = jacrev(jnp.linalg.eigh) +GRAD_REV_FN_SAFE_EIGH = jacrev(safe_eigh) + + +def rand_sym_mat(matrix_size: Scalar, rand_key: PRNGKey) -> Array: + """Generate a real symmetric matrix + + Args: + matrix_size (Scalar): the square dimensions of the real symmetric matrix to be generated. + rand_key (PRNGKey): the jax-type random key for seeding RNG. + + Returns: + Array: a random real symmetric matrix + """ + random_matrix = normal(rand_key, (matrix_size, matrix_size)) + return 0.5 * (random_matrix + random_matrix.T) + + +def generate_symmetric_matrix_with_degenerate_eigenvalue(matrix_size: Scalar, rand_key: PRNGKey): + """Generate a real symmetric matrix guaranteed to have one denegerate eigenvalue + + Args: + matrix_size (Scalar): the square dimensions of the real symmetric matrix to be generated. + rand_key (PRNGKey): the jax-type random key for seeding RNG. + + Returns: + Array: a random real symmetric matrix guaranteed to have one degenerate eigenvalue + """ + + sym_mat = rand_sym_mat(matrix_size, rand_key) + + # Add a degenerate eigenvalue by duplicating one eigenvalue + eigenvalues, eigenvectors = np.linalg.eigh(sym_mat) + index_to_duplicate = randint(rand_key, (1,), 0, matrix_size) + eigenvalues[index_to_duplicate] = eigenvalues[index_to_duplicate - 1] + + # Reconstruct the matrix with the modified eigenvalues + A = eigenvectors @ np.diag(eigenvalues) @ np.linalg.inv(eigenvectors) + + return A + + +def test_non_degen_rev_mode_jacobians() -> None: + r"""Check that the reverse mode jacobians match between the jnp.linalg.eigh implementation and our + custom safe_eigh implementation when no degeneracies are present. + Args: + None + Returns: + None + """ + for mat_size in MATRIX_SIZES: + for i, key in enumerate(RANDOM_KEYS): + sym_mat = rand_sym_mat(mat_size, key) + jac_jnp = GRAD_REV_FN_JNP(sym_mat) + jac_safe_eigh = GRAD_REV_FN_SAFE_EIGH(sym_mat) + assert jac_jnp[0] == pytest.approx( + jac_safe_eigh[0], abs=1e-10 + ), f"Reverse mode jacobian difference comparing jnp.linalg.eigh and safe_eigh for seed {SEEDS[i]} and matrix_size {mat_size} exceeds threshold: {ABS_THRESH}" + assert jac_jnp[1] == pytest.approx( + jac_safe_eigh[1], abs=1e-10 + ), f"Reverse mode jacobian difference comparing jnp.linalg.eigh and safe_eigh for seed {SEEDS[i]} and matrix_size {mat_size} exceeds threshold: {ABS_THRESH}" + + +def test_degen_rev_mode_jacobians_for_nans() -> None: + r"""Check that the reverse mode jacobian contains no NaNs when passed a symmetric real matrix with degenerate eigenvalues + + Args: + None + Returns: + None + """ + for mat_size in MATRIX_SIZES: + for i, key in enumerate(RANDOM_KEYS): + degen_sym_mat = generate_symmetric_matrix_with_degenerate_eigenvalue(mat_size, key) + jac_safe_eigh = GRAD_REV_FN_SAFE_EIGH(degen_sym_mat) + assert not jnp.isnan( + jac_safe_eigh[0] + ).any(), f"Reverse mode jacobian element 0 for safe_eigh for seed {SEEDS[i]} and matrix_size {mat_size} contained atleast one NaN when passed a matrix with degenerate eigenvalues/eigenvectors" + assert not jnp.isnan( + jac_safe_eigh[0] + ).any(), f"Reverse mode jacobian element 1 for safe_eigh for seed {SEEDS[i]} and matrix_size {mat_size} contained atleast one NaN when passed a matrix with degenerate eigenvalues/eigenvectors" From a18d5da75b042d9a10029ff688f7fc4eaafee45e Mon Sep 17 00:00:00 2001 From: Jack Baker Date: Fri, 8 Sep 2023 18:08:25 -0400 Subject: [PATCH 5/6] Added to automatic tests --- .github/workflows/install_and_test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/install_and_test.yml b/.github/workflows/install_and_test.yml index 7656253..484bab1 100644 --- a/.github/workflows/install_and_test.yml +++ b/.github/workflows/install_and_test.yml @@ -23,6 +23,9 @@ jobs: - name: Install extra example dependencies run: | pip install -e ".[examples]" + - name: Run unit tests + run: | + pytest -v tests/unit/test_eigenproblem.py - name: Run integration tests run: | pytest -v tests/integration/test_non_xc_energy.py From 487ef3d8ac0a5ba9315d1a75344d793a1cca57a4 Mon Sep 17 00:00:00 2001 From: Jack Baker Date: Fri, 8 Sep 2023 18:13:28 -0400 Subject: [PATCH 6/6] Removed old eigh functionality --- grad_dft/molecule.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/grad_dft/molecule.py b/grad_dft/molecule.py index 8261aec..f9ad526 100644 --- a/grad_dft/molecule.py +++ b/grad_dft/molecule.py @@ -16,7 +16,6 @@ from dataclasses import fields from grad_dft.utils import Array, Scalar, PyTree, vmap_chunked from functools import partial -from grad_dft.utils.eigenproblem import safe_general_eigh from jax import numpy as jnp from jax import scipy as jsp @@ -565,27 +564,9 @@ def chunked_jvp(chi_tensor, gr_tensor, ao_tensor): return (jax.jit(chunked_jvp)(chi.transpose(3, 0, 1, 2), gr, ao)).transpose(1, 2, 3, 0) - -def eig(h, x): - e0, c0 = eigh2d(h[0], x) - e1, c1 = eigh2d(h[1], x) - return jnp.stack((e0, e1), axis=0), jnp.stack((c0, c1), axis=0) - def abs_clip(arr, threshold): return jnp.where(jnp.abs(arr) > threshold, arr, 0) -def general_eigh(A, B): - L = jnp.linalg.cholesky(B) - L_inv = jnp.linalg.inv(L) - C = L_inv @ A @ L_inv.T - C = abs_clip(C, 1e-20) - eigenvalues, eigenvectors_transformed = jnp.linalg.eigh(C) - # eigenvalues, eigenvectors_transformed = jsp.linalg.eigh(C) - eigenvectors_original = L_inv.T @ eigenvectors_transformed - eigenvectors_original = abs_clip(eigenvectors_original, 1e-20) - eigenvalues = abs_clip(eigenvalues, 1e-20) - return eigenvalues, eigenvectors_original - ######################################################################