Skip to content

Commit

Permalink
docstrings and black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jackbaker1001 committed Sep 8, 2023
1 parent c25a8bf commit f1cd50c
Showing 1 changed file with 83 additions and 20 deletions.
103 changes: 83 additions & 20 deletions grad_dft/utils/eigenproblem.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,61 +21,124 @@
DEGEN_TOL = 1e-6
BROADENING = 1e-7


@custom_vjp
def safe_eigh(A: Array):
def safe_eigh(A: Array) -> tuple[Array, Array]:
r"""Get the eigenvalues and eigenvectors for an input real symmetric matrix.
A safe reverse mode gradient is implemented in safe_eigh_rev below.
Args:
A (Array): a 2D Jax array representing a real symmetric matrix.
Returns:
tuple[Array, Array]: the eigenvalues and eigenvectors of the input real symmetric matrix.
"""
evecs, evals = jnp.linalg.eigh(A)
return evecs.real, evals.real
return evecs, evals


def safe_eigh_fwd(A: Array) -> tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]:
r"""Forward mode operation of safe_eigh. Saves evecs and evals for the reverse pass.
def safe_eigh_fwd(A: Array):
Args:
A (Array): a 2D Jax array representing a real symmetric matrix.
Returns:
tuple[tuple[Array, Array], tuple[tuple[Array, Array], Array]]: eigenvectors, eigenvalues and the input real symmetric matrix A.
"""
evecs, evals = safe_eigh(A)
return (evecs, evals), ((evecs, evals), A)

def safe_eigh_rev(res: Array, g: Array):
"""Apply lorentzian broadening to degenerate eigenvalues/eignvectors

def safe_eigh_rev(res: tuple[tuple[Array, Array], Array], g: Array) -> tuple[Array]:
r"""Use the Lorentzian broading approach suggested in https://doi.org/10.1038/s42005-021-00568-6
to calculate stable backward mode gradients for degenerate eigenvectors. We only apply this
technique if eigenvalues are detected to be degenerate according to the constant DEGEN_TOL
in this module. When degeneracies are detected, the are broadened according to the constant
BROADENING also defined in this module.
Args:
res (tuple[tuple[Array, Array]): eigenvectors, eigenvales and the input real symmetric matrix A saved from the forward pass
g (Array): the gradients d[eigenvalues]/dA and d[eigenvectors]/dA
Returns:
tuple[Array]: the matrix of reverse mode gradients.
"""
(evals, evecs), A = res
grad_evals, grad_evecs = g
grad_evals_diag = jnp.diag(grad_evals)
evecs_trans = evecs.T

# Generate eigenvalue difference matrix
eval_diff = evals.reshape((1, -1)) - evals.reshape((-1, 1))
# Find elements where degen_tol condition was or wasn't was met
mask_degen = (jnp.abs(eval_diff) < DEGEN_TOL).astype(jnp.int32)
mask_non_degen = (jnp.abs(eval_diff) >= DEGEN_TOL).astype(jnp.int32)
mask_non_degen = (jnp.abs(eval_diff) >= DEGEN_TOL).astype(jnp.int32)

# Regular gap for non_degen terms => 1/(e_j - e_i)
# Will get +infs turning to large numbers here if degenerarcies are present.
# This doesn't matters as they multiply by 0 in the forthcoming mask when calculating
# Will get +infs turning to large numbers here if degeneracies are present.
# This doesn't matter as they multiply by 0 in the forthcoming mask when calculating
# the F-matrix
regular_gap = jnp.nan_to_num(jnp.divide(1, eval_diff))

# Lorentzian broadened gap for degen terms => (e_j - e_i)/((e_j - e_i)^2 + eps)
broadened_gap = eval_diff / (eval_diff*eval_diff + BROADENING)
# Calculate full F matrix. large numbers generated by NaNs from regular_gap are deleted here
F = 0.5*(jnp.multiply(mask_non_degen, regular_gap) + jnp.multiply(mask_degen, broadened_gap))
broadened_gap = eval_diff / (eval_diff * eval_diff + BROADENING)

# Calculate full F matrix. large numbers generated by NaNs from regular_gap are deleted here
F = 0.5 * (jnp.multiply(mask_non_degen, regular_gap) + jnp.multiply(mask_degen, broadened_gap))

# Set diagonals to 0
F = F.at[jnp.diag_indices_from(F)].set(0)

# Calculate the gradient
grad = jnp.linalg.inv(evecs_trans) @ (0.5*grad_evals_diag + jnp.multiply(F, evecs_trans @ grad_evecs)) @ evecs_trans
grad = (
jnp.linalg.inv(evecs_trans)
@ (0.5 * grad_evals_diag + jnp.multiply(F, evecs_trans @ grad_evecs))
@ evecs_trans
)
# Symmetrize
grad_sym = grad + grad.T
return grad_sym,
return (grad_sym,)


safe_eigh.defvjp(safe_eigh_fwd, safe_eigh_rev)

def safe_general_eigh(A: Array, B: Array):

def safe_general_eigh(A: Array, B: Array) -> tuple[Array, Array]:
r"""Solve the general eigenproblem for the eigenvalues and eigenvectors. I.e,
. math::
AC = ECB
for matrix of eigenvectors C and diagonal matrix of eigenvalues E. This function requires all input
matrices to real and symmetric and the matrix B to be invertible.
Args:
A (Array): a real symmetric matrix
B (Array): another real symmetric matrix
Returns:
tuple[Array, Array]: the eigenvalues and matrix of eigenvectors
"""
L = jnp.linalg.cholesky(B)
L_inv = jnp.linalg.inv(L)
C = L_inv @ A @ L_inv.T
eigenvalues, eigenvectors_transformed = safe_eigh(C)
eigenvectors_original = L_inv.T @ eigenvectors_transformed
return eigenvalues, eigenvectors_original

def safe_fock_solver(fock: Array, overlap: Array):

def safe_fock_solver(fock: tuple[Array, Array], overlap: Array) -> tuple[Array, Array]:
"""Get the eigenenergies and molecular orbital coefficients for the
up and down fock spin matrices.
Args:
fock (tuple[Array, Array]): the up and down fock spin matrices
overlap (Array): the overlap matrix
Returns:
tuple[Array, Array]: the eigenenergies and matrix of molecular orbital coefficients.
"""
mo_energies_up, mo_coeffs_up = safe_general_eigh(fock[0], overlap)
mo_energies_dn, mo_coeffs_dn = safe_general_eigh(fock[1], overlap)
return jnp.stack((mo_energies_up, mo_energies_dn), axis=0), jnp.stack((mo_coeffs_up, mo_coeffs_dn), axis=0)
return jnp.stack((mo_energies_up, mo_energies_dn), axis=0), jnp.stack(
(mo_coeffs_up, mo_coeffs_dn), axis=0
)

0 comments on commit f1cd50c

Please sign in to comment.