From 487ef3d8ac0a5ba9315d1a75344d793a1cca57a4 Mon Sep 17 00:00:00 2001 From: Jack Baker Date: Fri, 8 Sep 2023 18:13:28 -0400 Subject: [PATCH] Removed old eigh functionality --- grad_dft/molecule.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/grad_dft/molecule.py b/grad_dft/molecule.py index 8261aec..f9ad526 100644 --- a/grad_dft/molecule.py +++ b/grad_dft/molecule.py @@ -16,7 +16,6 @@ from dataclasses import fields from grad_dft.utils import Array, Scalar, PyTree, vmap_chunked from functools import partial -from grad_dft.utils.eigenproblem import safe_general_eigh from jax import numpy as jnp from jax import scipy as jsp @@ -565,27 +564,9 @@ def chunked_jvp(chi_tensor, gr_tensor, ao_tensor): return (jax.jit(chunked_jvp)(chi.transpose(3, 0, 1, 2), gr, ao)).transpose(1, 2, 3, 0) - -def eig(h, x): - e0, c0 = eigh2d(h[0], x) - e1, c1 = eigh2d(h[1], x) - return jnp.stack((e0, e1), axis=0), jnp.stack((c0, c1), axis=0) - def abs_clip(arr, threshold): return jnp.where(jnp.abs(arr) > threshold, arr, 0) -def general_eigh(A, B): - L = jnp.linalg.cholesky(B) - L_inv = jnp.linalg.inv(L) - C = L_inv @ A @ L_inv.T - C = abs_clip(C, 1e-20) - eigenvalues, eigenvectors_transformed = jnp.linalg.eigh(C) - # eigenvalues, eigenvectors_transformed = jsp.linalg.eigh(C) - eigenvectors_original = L_inv.T @ eigenvectors_transformed - eigenvectors_original = abs_clip(eigenvectors_original, 1e-20) - eigenvalues = abs_clip(eigenvalues, 1e-20) - return eigenvalues, eigenvectors_original - ######################################################################