Changes to allow solid computation in train.py
jackbaker1001 committed Nov 9, 2023
1 parent f4ca97b commit 290cb38
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions grad_dft/train.py
@@ -143,10 +143,16 @@ def predict(params: PyTree, atoms: Union[Molecule, Solid], *args) -> Tuple[Scala
(*batch_size, n_spin, n_orbitals, n_orbitals) for a `Molecule` or
(*batch_size, n_spin, n_kpt, n_orbitals, n_orbitals) for a `Solid`.
"""

energy, fock = energy_and_grads(params, atoms.rdm1, atoms, *args)

# Improve stability by clipping and symmetrizing
if isinstance(atoms, Molecule):
transpose_dims = (0, 2, 1)
elif isinstance(atoms, Solid):
transpose_dims = (0, 1, 3, 2)
fock = abs_clip(fock, clip_cte)
fock = 1 / 2 * (fock + fock.transpose(0, 2, 1))
fock = 1 / 2 * (fock + fock.transpose(transpose_dims).conj())
fock = abs_clip(fock, clip_cte)

# Compute the features that should be autodifferentiated
@@ -189,14 +195,15 @@ def predict(params: PyTree, atoms: Union[Molecule, Solid], *args) -> Tuple[Scala
vxc_expl = functional.densitygrads(
functional, params, atoms, nograd_densities, cinputs, grad_densities
)
fock += vxc_expl + vxc_expl.transpose(0, 2, 1) # Sum over omega
fock += vxc_expl + vxc_expl.transpose(transpose_dims) # Sum over omega
fock = abs_clip(fock, clip_cte)

if functional.coefficient_input_grads:
vxc_expl = functional.coefficient_input_grads(
functional, params, atoms, nograd_cinputs, grad_cinputs, densities
)
fock += vxc_expl + vxc_expl.transpose(0, 2, 1) # Sum over omega
fock += vxc_expl + vxc_expl.transpose(transpose_dims) # Sum over omega
fock = abs_clip(fock, clip_cte)

fock = abs_clip(fock, clip_cte)
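
The Hermitian symmetrization in the first hunk swaps only the two orbital axes, which sit last in both layouts: (n_spin, n_orbitals, n_orbitals) for a Molecule and (n_spin, n_kpt, n_orbitals, n_orbitals) for a Solid. A minimal JAX check of that axis convention for the Solid case (editor's sketch; the shapes and the random complex Fock matrix are illustrative assumptions, not code from this commit):

import jax
import jax.numpy as jnp

# Illustrative shapes: 2 spins, 4 k-points, 6 orbitals.
n_spin, n_kpt, n_orb = 2, 4, 6
k_re, k_im = jax.random.split(jax.random.PRNGKey(0))
fock_re = jax.random.normal(k_re, (n_spin, n_kpt, n_orb, n_orb))
fock_im = jax.random.normal(k_im, (n_spin, n_kpt, n_orb, n_orb))
fock = fock_re + 1j * fock_im

# (0, 1, 3, 2) leaves the spin and k-point axes alone and swaps the orbital
# axes, exactly as (0, 2, 1) does for the 3-axis Molecule layout.
herm = 0.5 * (fock + fock.transpose((0, 1, 3, 2)).conj())

# Every k-point block of the result is Hermitian.
assert jnp.allclose(herm, herm.transpose((0, 1, 3, 2)).conj())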
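Taken together, the two hunks pick the transpose axes once from the system type and reuse them both for the initial symmetrization and for folding the explicit functional gradients back into the Fock matrix. A condensed sketch of that flow (editor's illustration only: symmetrize_and_accumulate, is_solid, vxc_terms, and the zeroing behaviour assumed for abs_clip are hypothetical, not the repository's API):

import jax.numpy as jnp

def abs_clip_sketch(x, clip_cte):
    # Assumed behaviour: zero out entries whose magnitude falls below clip_cte.
    return jnp.where(jnp.abs(x) > clip_cte, x, 0.0)

def symmetrize_and_accumulate(fock, vxc_terms, is_solid, clip_cte):
    # Solid Fock matrices carry an extra k-point axis, so the orbital axes shift.
    transpose_dims = (0, 1, 3, 2) if is_solid else (0, 2, 1)

    # First hunk: clip, Hermitian-symmetrize (the conjugate matters for the
    # complex k-dependent blocks), and clip again.
    fock = abs_clip_sketch(fock, clip_cte)
    fock = 0.5 * (fock + fock.transpose(transpose_dims).conj())
    fock = abs_clip_sketch(fock, clip_cte)

    # Second hunk: add each explicit gradient contribution together with its
    # transpose, re-clipping after each addition.
    for vxc in vxc_terms:
        fock = fock + vxc + vxc.transpose(transpose_dims)
        fock = abs_clip_sketch(fock, clip_cte)

    return abs_clip_sketch(fock, clip_cte)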
