Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable backwards gradients for eigh when passed a degenerate eigenproblem #38

Merged
merged 6 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/install_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ jobs:
- name: Install extra example dependencies
run: |
pip install -e ".[examples]"
- name: Run unit tests
run: |
pytest -v tests/unit/test_eigenproblem.py
- name: Run integration tests
run: |
pytest -v tests/integration/test_non_xc_energy.py
Expand Down
12 changes: 6 additions & 6 deletions grad_dft/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from grad_dft.utils import PyTree, Array, Scalar, Optimizer
from grad_dft.functional import Functional

from grad_dft.molecule import Molecule, abs_clip, make_rdm1, orbital_grad, general_eigh
from grad_dft.molecule import Molecule, abs_clip, make_rdm1, orbital_grad
from grad_dft.train import molecule_predictor
from grad_dft.utils import PyTree, Array, Scalar
from grad_dft.utils import PyTree, Array, Scalar, safe_fock_solver
from grad_dft.interface.pyscf import (
generate_chi_tensor,
mol_from_Molecule,
Expand Down Expand Up @@ -117,7 +117,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
else:
# Diagonalize Fock matrix
overlap = abs_clip(molecule.s1e, 1e-20)
mo_energy, mo_coeff = general_eigh(fock, overlap)
mo_energy, mo_coeff = safe_fock_solver(fock, overlap)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -275,7 +275,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
)

# Diagonalize Fock matrix
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -372,7 +372,7 @@ def nelec_cost_fn(m, mo_es, sigma, _nelectron):
if abs(predicted_e - old_e) * Hartree2kcalmol < e_conv and norm_gorb < g_conv:
# We perform an extra diagonalization to remove the level shift
# Solve eigenvalue problem
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -662,7 +662,7 @@ def loop_body(cycle, state):
fock, diis_data = diis.run(new_data, diis_data, cycle)

# Diagonalize Fock matrix
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
mo_energy, mo_coeff = safe_fock_solver(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down
1 change: 0 additions & 1 deletion grad_dft/external/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@
from grad_dft.external.density_functional_approximation_dm21.density_functional_approximation_dm21.neural_numint import (
_SystemState,
)
from grad_dft.external.eigh_impl import eigh2d
116 changes: 0 additions & 116 deletions grad_dft/external/eigh_impl.py

This file was deleted.

19 changes: 0 additions & 19 deletions grad_dft/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from dataclasses import fields
from grad_dft.utils import Array, Scalar, PyTree, vmap_chunked
from functools import partial
from grad_dft.external.eigh_impl import eigh2d

from jax import numpy as jnp
from jax import scipy as jsp
Expand Down Expand Up @@ -565,27 +564,9 @@ def chunked_jvp(chi_tensor, gr_tensor, ao_tensor):

return (jax.jit(chunked_jvp)(chi.transpose(3, 0, 1, 2), gr, ao)).transpose(1, 2, 3, 0)


def eig(h, x):
    """Apply ``eigh2d`` to ``h[0]`` and ``h[1]`` and stack the results.

    Both calls receive the same second argument ``x``. ``eigh2d`` is imported
    from ``grad_dft.external.eigh_impl``; its exact semantics are not visible
    here (presumably a 2D eigensolver returning ``(values, vectors)`` --
    verify against that module).

    Args:
        h: indexable with at least two entries; ``h[0]`` and ``h[1]`` are
            each passed to ``eigh2d``.
        x: forwarded unchanged as the second argument of both calls.

    Returns:
        A pair ``(first_outputs, second_outputs)``; each is the two per-entry
        results stacked along a new leading axis.
    """
    per_entry = [eigh2d(h[idx], x) for idx in (0, 1)]
    firsts, seconds = zip(*per_entry)
    return jnp.stack(firsts, axis=0), jnp.stack(seconds, axis=0)

def abs_clip(arr, threshold):
    """Zero every entry of ``arr`` whose magnitude does not exceed ``threshold``.

    Args:
        arr: array-like input.
        threshold: entries with ``|value| > threshold`` are kept unchanged,
            all others are replaced by 0.

    Returns:
        An array of the same shape as ``arr`` with small entries zeroed.
    """
    is_significant = jnp.abs(arr) > threshold
    return jnp.where(is_significant, arr, 0)

def general_eigh(A, B):
    """Solve a generalized eigenproblem via a Cholesky reduction of ``B``.

    ``B`` is factorised as ``L @ L.T``; the reduced matrix
    ``inv(L) @ A @ inv(L).T`` is diagonalised with ``jnp.linalg.eigh`` and the
    eigenvectors are mapped back with ``inv(L).T``. The reduced matrix, the
    back-transformed eigenvectors and the eigenvalues each have entries of
    magnitude at most 1e-20 set to zero through ``abs_clip``.

    Args:
        A: matrix to diagonalise.
        B: matrix that is Cholesky-factorised.

    Returns:
        ``(eigenvalues, eigenvectors)`` after clipping.
    """
    tiny = 1e-20
    chol_inv = jnp.linalg.inv(jnp.linalg.cholesky(B))
    reduced = abs_clip(chol_inv @ A @ chol_inv.T, tiny)
    evals, evecs_reduced = jnp.linalg.eigh(reduced)
    # Map the eigenvectors of the reduced problem back to the original basis.
    evecs = abs_clip(chol_inv.T @ evecs_reduced, tiny)
    return abs_clip(evals, tiny), evecs


######################################################################

Expand Down
1 change: 1 addition & 0 deletions grad_dft/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
from .tree import tree_size, tree_isfinite, tree_randn_like, tree_func, tree_shape
from .utils import to_device_arrays, Utils
from .chunk import vmap_chunked
from .eigenproblem import safe_fock_solver
144 changes: 144 additions & 0 deletions grad_dft/utils/eigenproblem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
from jax import custom_vjp
from .types import Array, Scalar

# Probably don't alter these unless you know what you're doing.
# DEGEN_TOL: pairs of eigenvalues closer than this are treated as degenerate
# in the backward pass of safe_eigh.
# BROADENING: regulariser added to the squared eigenvalue gap of such pairs.
DEGEN_TOL = 1e-6
BROADENING = 1e-7


@custom_vjp
def safe_eigh(A: Array) -> tuple[Array, Array]:
    r"""Get the eigenvalues and eigenvectors of an input real symmetric matrix.

    The forward computation is a plain ``jnp.linalg.eigh``. A reverse mode
    gradient that stays finite for degenerate spectra is implemented in
    ``safe_eigh_rev`` and registered with ``safe_eigh.defvjp`` below.

    Args:
        A (Array): a 2D Jax array representing a real symmetric matrix.

    Returns:
        tuple[Array, Array]: the eigenvalues and the matrix whose columns are
        the corresponding eigenvectors, in that order.
    """
    # jnp.linalg.eigh returns (eigenvalues, eigenvectors), in that order.
    evals, evecs = jnp.linalg.eigh(A)
    return evals, evecs


def safe_eigh_fwd(A: Array) -> tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]:
    r"""Forward pass of safe_eigh for reverse mode differentiation.

    Args:
        A (Array): a 2D Jax array representing a real symmetric matrix.

    Returns:
        tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]: the primal
        output ``(eigenvalues, eigenvectors)`` followed by the residuals
        ``((eigenvalues, eigenvectors), A)`` saved for the backward pass.
    """
    evals, evecs = safe_eigh(A)
    return (evals, evecs), ((evals, evecs), A)


def safe_eigh_rev(
    res: tuple[tuple[Array, Array], Array], g: tuple[Array, Array]
) -> tuple[Array]:
    r"""Backward pass of safe_eigh.

    Uses the Lorentzian broadening approach suggested in
    https://doi.org/10.1038/s42005-021-00568-6 to calculate stable backward mode
    gradients for degenerate eigenvectors. The technique is only applied to pairs
    of eigenvalues detected to be degenerate according to the constant DEGEN_TOL
    in this module. When degeneracies are detected, the gaps are broadened
    according to the constant BROADENING, also defined in this module.

    Args:
        res (tuple[tuple[Array, Array], Array]): the residuals
            ``((eigenvalues, eigenvectors), A)`` saved by safe_eigh_fwd.
            ``A`` is not needed by this function.
        g (tuple[Array, Array]): the cotangents of the eigenvalues and of the
            eigenvectors.

    Returns:
        tuple[Array]: a 1-tuple holding the (symmetrized) cotangent of the
        input matrix.
    """
    (evals, evecs), _ = res
    grad_evals, grad_evecs = g

    # Eigenvalue difference matrix: eval_diff[i, j] = e_j - e_i
    eval_diff = evals.reshape((1, -1)) - evals.reshape((-1, 1))
    # True where the degeneracy condition is met (this includes the diagonal)
    is_degen = jnp.abs(eval_diff) < DEGEN_TOL

    # Regular gap for non-degenerate terms => 1/(e_j - e_i).
    # Degenerate entries get a placeholder denominator of 1 so that no division
    # by zero happens; those entries are discarded by the jnp.where below.
    regular_gap = 1.0 / jnp.where(is_degen, 1.0, eval_diff)

    # Lorentzian broadened gap for degenerate terms
    # => (e_j - e_i)/((e_j - e_i)^2 + eps)
    broadened_gap = eval_diff / (eval_diff * eval_diff + BROADENING)

    # Full F matrix. The factor 0.5 is compensated by the symmetrization below.
    F = 0.5 * jnp.where(is_degen, broadened_gap, regular_gap)

    # Set diagonals to 0
    F = F.at[jnp.diag_indices_from(F)].set(0)

    # The eigenvectors returned by eigh are orthonormal, so the inverse of
    # evecs.T is evecs itself and no explicit matrix inversion is required.
    inner = 0.5 * jnp.diag(grad_evals) + jnp.multiply(F, evecs.T @ grad_evecs)
    grad = evecs @ inner @ evecs.T

    # Symmetrize
    return (grad + grad.T,)


safe_eigh.defvjp(safe_eigh_fwd, safe_eigh_rev)


def safe_general_eigh(A: Array, B: Array) -> tuple[Array, Array]:
    r"""Solve the generalized symmetric eigenproblem for eigenvalues and eigenvectors.

    I.e.,

    .. math::
        A C = B C E

    for the matrix of eigenvectors C and the diagonal matrix of eigenvalues E.
    The problem is reduced to a standard one through the Cholesky
    factorisation :math:`B = L L^T`: the matrix :math:`L^{-1} A L^{-T}` is
    diagonalised with ``safe_eigh`` and its eigenvectors are mapped back with
    :math:`L^{-T}`. This function requires both input matrices to be real and
    symmetric, and B to admit a Cholesky factorisation (i.e. to be positive
    definite).

    Args:
        A (Array): a real symmetric matrix
        B (Array): a real symmetric positive definite matrix

    Returns:
        tuple[Array, Array]: the eigenvalues and matrix of eigenvectors
    """
    # Local import so that the module-level imports are left untouched.
    from jax.scipy.linalg import solve_triangular

    L = jnp.linalg.cholesky(B)
    # Triangular solves replace the explicit inverse of L: they are cheaper
    # and numerically better behaved than forming inv(L).
    L_inv_A = solve_triangular(L, A, lower=True)  # L^{-1} A
    # L^{-1} (L^{-1} A)^T = L^{-1} A^T L^{-T}; transposing gives L^{-1} A L^{-T}.
    C = solve_triangular(L, L_inv_A.T, lower=True).T
    eigenvalues, eigenvectors_transformed = safe_eigh(C)
    # Back-transform: solve L^T X = V, i.e. X = L^{-T} V.
    eigenvectors_original = solve_triangular(
        L, eigenvectors_transformed, trans="T", lower=True
    )
    return eigenvalues, eigenvectors_original


def safe_fock_solver(fock: tuple[Array, Array], overlap: Array) -> tuple[Array, Array]:
    """Diagonalise the up and down spin Fock matrices against a shared overlap.

    ``fock[0]`` (up) and ``fock[1]`` (down) are each passed to
    ``safe_general_eigh`` together with ``overlap``, and the per-spin results
    are stacked along a new leading axis.

    Args:
        fock (tuple[Array, Array]): the up and down fock spin matrices
        overlap (Array): the overlap matrix

    Returns:
        tuple[Array, Array]: the stacked eigenenergies and the stacked matrices
        of molecular orbital coefficients, up spin first.
    """
    # Index 0 is the up channel, index 1 the down channel.
    per_spin = [safe_general_eigh(fock[spin], overlap) for spin in (0, 1)]
    mo_energies, mo_coeffs = zip(*per_spin)
    return jnp.stack(mo_energies, axis=0), jnp.stack(mo_coeffs, axis=0)
Loading