diff --git a/grad_dft/train.py b/grad_dft/train.py index 15763fe..60c74e6 100644 --- a/grad_dft/train.py +++ b/grad_dft/train.py @@ -143,10 +143,16 @@ def predict(params: PyTree, atoms: Union[Molecule, Solid], *args) -> Tuple[Scala (*batch_size, n_spin, n_orbitals, n_orbitals) for a `Molecule` or (*batch_size, n_spin, n_kpt, n_orbitals, n_orbitals) for a `Solid`. """ - + energy, fock = energy_and_grads(params, atoms.rdm1, atoms, *args) + + # Improve stability by clipping and symmetrizing + if isinstance(atoms, Molecule): + transpose_dims = (0, 2, 1) + elif isinstance(atoms, Solid): + transpose_dims = (0, 1, 3, 2) fock = abs_clip(fock, clip_cte) - fock = 1 / 2 * (fock + fock.transpose(0, 2, 1)) + fock = 1 / 2 * (fock + fock.transpose(transpose_dims).conj()) fock = abs_clip(fock, clip_cte) # Compute the features that should be autodifferentiated @@ -189,14 +195,15 @@ def predict(params: PyTree, atoms: Union[Molecule, Solid], *args) -> Tuple[Scala vxc_expl = functional.densitygrads( functional, params, atoms, nograd_densities, cinputs, grad_densities ) - fock += vxc_expl + vxc_expl.transpose(0, 2, 1) # Sum over omega + print(vxc_expl.shape) + fock += vxc_expl + vxc_expl.transpose(transpose_dims) # Sum over omega fock = abs_clip(fock, clip_cte) if functional.coefficient_input_grads: vxc_expl = functional.coefficient_input_grads( functional, params, atoms, nograd_cinputs, grad_cinputs, densities ) - fock += vxc_expl + vxc_expl.transpose(0, 2, 1) # Sum over omega + fock += vxc_expl + vxc_expl.transpose(transpose_dims) # Sum over omega fock = abs_clip(fock, clip_cte) fock = abs_clip(fock, clip_cte)