Skip to content

Commit

Permalink
Update scripts for array API changes
Browse files Browse the repository at this point in the history
  • Loading branch information
obackhouse committed Sep 21, 2024
1 parent 05703b1 commit 7802aec
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 22 deletions.
4 changes: 2 additions & 2 deletions ebcc/codegen/bootstrap_CCSD.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def tensor_cleanup(self, *args):
preamble += "\nr1new = Namespace()\nr2new = Namespace()"
kwargs = {
"preamble": preamble,
"postamble": "r2new.baba = r2new.abab.transpose(1, 0, 3, 2)" if spin == "uhf" else None, # FIXME
"postamble": "r2new.baba = np.transpose(r2new.abab, (1, 0, 3, 2))" if spin == "uhf" else None, # FIXME
"as_dict": True,
}
else:
Expand Down Expand Up @@ -748,7 +748,7 @@ def tensor_cleanup(self, *args):
preamble += "\nr1new = Namespace()\nr2new = Namespace()"
kwargs = {
"preamble": preamble,
"postamble": "r2new.baba = r2new.abab.transpose(1, 0, 3, 2)" if spin == "uhf" else None, # FIXME
"postamble": "r2new.baba = np.transpose(r2new.abab, (1, 0, 3, 2))" if spin == "uhf" else None, # FIXME
"as_dict": True,
}
else:
Expand Down
4 changes: 2 additions & 2 deletions ebcc/codegen/bootstrap_DFCCSD.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@
preamble += "\nr1new = Namespace()\nr2new = Namespace()"
kwargs = {
"preamble": preamble,
"postamble": "r2new.baba = r2new.abab.transpose(1, 0, 3, 2)" if spin == "uhf" else None, # FIXME
"postamble": "r2new.baba = np.transpose(r2new.abab, (1, 0, 3, 2))" if spin == "uhf" else None, # FIXME
"as_dict": True,
}
else:
Expand Down Expand Up @@ -756,7 +756,7 @@
preamble += "\nr1new = Namespace()\nr2new = Namespace()"
kwargs = {
"preamble": preamble,
"postamble": "r2new.baba = r2new.abab.transpose(1, 0, 3, 2)" if spin == "uhf" else None, # FIXME
"postamble": "r2new.baba = np.transpose(r2new.abab, (1, 0, 3, 2))" if spin == "uhf" else None, # FIXME
"as_dict": True,
}
else:
Expand Down
2 changes: 1 addition & 1 deletion ebcc/codegen/bootstrap_MPn.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def get_preamble(n, spin, name="rdm{n}"):
preamble += "\nr1new = Namespace()\nr2new = Namespace()"
kwargs = {
"preamble": preamble,
"postamble": "r2new.baba = r2new.abab.transpose(1, 0, 3, 2)" if spin == "uhf" else None, # FIXME
"postamble": "r2new.baba = np.transpose(r2new.abab, (1, 0, 3, 2))" if spin == "uhf" else None, # FIXME
"as_dict": True,
}
else:
Expand Down
34 changes: 17 additions & 17 deletions ebcc/codegen/bootstrap_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self,
einsum_func="einsum",
einsum_kwargs=None,
transpose_func="{arg}.transpose({transpose})",
transpose_func="np.transpose({arg}, {transpose})",
name_generator=None,
spin="ghf",
**kwargs,
Expand Down Expand Up @@ -487,13 +487,13 @@ def get_density_einsum_postamble(n, spin, name="rdm{n}", spaces=None):
f"{name}.bbbb.{perm}" for perm in spaces
)
postamble += f"\n{name} = Namespace("
postamble += f"\n aaaa={name}.aaaa.swapaxes(1, 2),"
postamble += f"\n aabb={name}.abab.swapaxes(1, 2),"
postamble += f"\n bbbb={name}.bbbb.swapaxes(1, 2),"
postamble += f"\n aaaa=np.transpose({name}.aaaa, (0, 2, 1, 3)),"
postamble += f"\n aabb=np.transpose({name}.abab, (0, 2, 1, 3)),"
postamble += f"\n bbbb=np.transpose({name}.bbbb, (0, 2, 1, 3)),"
postamble += f"\n)"
else:
postamble = f"{name} = pack_2e(%s)" % ", ".join(f"{name}.{perm}" for perm in spaces)
postamble += f"\n{name} = {name}.swapaxes(1, 2)"
postamble += f"\n{name} = np.transpose({name}, (0, 2, 1, 3))"
return postamble


Expand All @@ -502,24 +502,24 @@ def get_boson_einsum_preamble(spin):
if spin == "uhf":
preamble = "gc = Namespace("
preamble += "\n aa=Namespace("
preamble += "\n boo=g.aa.boo.transpose(0, 2, 1),"
preamble += "\n bov=g.aa.bvo.transpose(0, 2, 1),"
preamble += "\n bvo=g.aa.bov.transpose(0, 2, 1),"
preamble += "\n bvv=g.aa.bvv.transpose(0, 2, 1),"
preamble += "\n boo=np.transpose(g.aa.boo, (0, 2, 1)),"
preamble += "\n bov=np.transpose(g.aa.bvo, (0, 2, 1)),"
preamble += "\n bvo=np.transpose(g.aa.bov, (0, 2, 1)),"
preamble += "\n bvv=np.transpose(g.aa.bvv, (0, 2, 1)),"
preamble += "\n ),"
preamble += "\n bb=Namespace("
preamble += "\n boo=g.bb.boo.transpose(0, 2, 1),"
preamble += "\n bov=g.bb.bvo.transpose(0, 2, 1),"
preamble += "\n bvo=g.bb.bov.transpose(0, 2, 1),"
preamble += "\n bvv=g.bb.bvv.transpose(0, 2, 1),"
preamble += "\n boo=np.transpose(g.bb.boo, (0, 2, 1)),"
preamble += "\n bov=np.transpose(g.bb.bvo, (0, 2, 1)),"
preamble += "\n bvo=np.transpose(g.bb.bov, (0, 2, 1)),"
preamble += "\n bvv=np.transpose(g.bb.bvv, (0, 2, 1)),"
preamble += "\n ),"
preamble += "\n)"
else:
preamble = "gc = Namespace("
preamble += "\n boo=g.boo.transpose(0, 2, 1),"
preamble += "\n bov=g.bvo.transpose(0, 2, 1),"
preamble += "\n bvo=g.bov.transpose(0, 2, 1),"
preamble += "\n bvv=g.bvv.transpose(0, 2, 1),"
preamble += "\n boo=np.transpose(g.boo, (0, 2, 1)),"
preamble += "\n bov=np.transpose(g.bvo, (0, 2, 1)),"
preamble += "\n bvo=np.transpose(g.bov, (0, 2, 1)),"
preamble += "\n bvv=np.transpose(g.bvv, (0, 2, 1)),"
preamble += "\n)"
return preamble

Expand Down

0 comments on commit 7802aec

Please sign in to comment.