-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #80 from BoothGroup/numpy_switch
Switchable numpy backends
- Loading branch information
Showing
128 changed files
with
13,863 additions
and
12,777 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
"""Backend for NumPy operations. | ||
Notes: | ||
Currently, the following backends are supported: | ||
- NumPy | ||
- CuPy | ||
- TensorFlow | ||
- JAX | ||
- CTF (Cyclops Tensor Framework) | ||
Non-NumPy backends are only lightly supported. Some functionality may not be available, and only | ||
minimal tests are performed. Some operations that require interaction with NumPy such as the | ||
PySCF interfaces may not be efficient, due to the need to convert between NumPy and the backend | ||
array types. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import importlib | ||
from typing import TYPE_CHECKING | ||
|
||
from ebcc import BACKEND | ||
|
||
if TYPE_CHECKING: | ||
from types import ModuleType | ||
from typing import Union, TypeVar, Optional | ||
|
||
from numpy import int64, generic | ||
from numpy.typing import NDArray | ||
|
||
T = TypeVar("T", bound=generic) | ||
|
||
if BACKEND == "numpy": | ||
import numpy as np | ||
elif BACKEND == "cupy": | ||
import cupy as np # type: ignore[no-redef] | ||
elif BACKEND == "tensorflow": | ||
import tensorflow as tf | ||
import tensorflow.experimental.numpy as np # type: ignore[no-redef] | ||
elif BACKEND == "jax": | ||
import jax | ||
import jax.numpy as np # type: ignore[no-redef] | ||
elif BACKEND in ("ctf", "cyclops"): | ||
import ctf | ||
|
||
|
||
def __getattr__(name: str) -> ModuleType: | ||
"""Get the backend module.""" | ||
return importlib.import_module(f"ebcc.backend._{BACKEND.lower()}") | ||
|
||
|
||
def ensure_scalar(obj: Union[T, NDArray[T]]) -> T: | ||
"""Ensure that an object is a scalar. | ||
Args: | ||
obj: Object to ensure is a scalar. | ||
Returns: | ||
Scalar object. | ||
""" | ||
if BACKEND in ("numpy", "cupy", "jax"): | ||
return np.asarray(obj).item() # type: ignore | ||
elif BACKEND == "tensorflow": | ||
if isinstance(obj, tf.Tensor): | ||
return obj.numpy().item() # type: ignore | ||
return obj # type: ignore | ||
elif BACKEND in ("ctf", "cyclops"): | ||
if isinstance(obj, ctf.tensor): | ||
return obj.to_nparray().item() # type: ignore | ||
return obj # type: ignore | ||
else: | ||
raise NotImplementedError(f"`ensure_scalar` not implemented for backend {BACKEND}.") | ||
|
||
|
||
def to_numpy(array: NDArray[T], dtype: Optional[type[generic]] = None) -> NDArray[T]: | ||
"""Convert an array to NumPy. | ||
Args: | ||
array: Array to convert. | ||
dtype: Data type to convert to. | ||
Returns: | ||
Array in NumPy format. | ||
Notes: | ||
This function does not guarantee a copy of the array. | ||
""" | ||
if BACKEND == "numpy": | ||
ndarray = array | ||
elif BACKEND == "cupy": | ||
ndarray = np.asnumpy(array) # type: ignore | ||
elif BACKEND == "jax": | ||
ndarray = np.array(array) # type: ignore | ||
elif BACKEND == "tensorflow": | ||
ndarray = array.numpy() # type: ignore | ||
elif BACKEND in ("ctf", "cyclops"): | ||
ndarray = array.to_nparray() # type: ignore | ||
else: | ||
raise NotImplementedError(f"`to_numpy` not implemented for backend {BACKEND}.") | ||
if dtype is not None and ndarray.dtype != dtype: | ||
ndarray = ndarray.astype(dtype) | ||
return ndarray | ||
|
||
|
||
def _put( | ||
array: NDArray[T], | ||
indices: Union[NDArray[int64], tuple[NDArray[int64], ...]], | ||
values: NDArray[T], | ||
) -> NDArray[T]: | ||
"""Put values into an array at specified indices. | ||
Args: | ||
array: Array to put values into. | ||
indices: Indices to put values at. | ||
values: Values to put into the array. | ||
Returns: | ||
Array with values put at specified indices. | ||
Notes: | ||
This function does not guarantee a copy of the array. | ||
""" | ||
if BACKEND == "numpy" or BACKEND == "cupy": | ||
if isinstance(indices, tuple): | ||
indices_flat = np.ravel_multi_index(indices, array.shape) | ||
np.put(array, indices_flat, values) | ||
else: | ||
np.put(array, indices, values) | ||
return array | ||
elif BACKEND == "jax": | ||
if isinstance(indices, tuple): | ||
indices_flat = np.ravel_multi_index(indices, array.shape) | ||
array = np.put(array, indices_flat, values, inplace=False) # type: ignore | ||
else: | ||
array = np.put(array, indices, values, inplace=False) # type: ignore | ||
return array | ||
elif BACKEND == "tensorflow": | ||
if isinstance(indices, (tuple, list)): | ||
indices_grid = tf.meshgrid(*indices, indexing="ij") | ||
indices = tf.stack([np.ravel(tf.cast(idx, tf.int32)) for idx in indices_grid], axis=1) | ||
else: | ||
indices = tf.cast(tf.convert_to_tensor(indices), tf.int32) | ||
indices = tf.expand_dims(indices, axis=-1) | ||
values = np.ravel(tf.convert_to_tensor(values, dtype=array.dtype)) | ||
return tf.tensor_scatter_nd_update(array, indices, values) # type: ignore | ||
elif BACKEND in ("ctf", "cyclops"): | ||
# TODO MPI has to be manually managed here | ||
if isinstance(indices, tuple): | ||
indices_flat = np.ravel_multi_index(indices, array.shape) | ||
array.write(indices_flat, values) # type: ignore | ||
else: | ||
array.write(indices, values) # type: ignore | ||
return array | ||
else: | ||
raise NotImplementedError(f"`_put` not implemented for backend {BACKEND}.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# type: ignore | ||
"""Cyclops Tensor Framework backend.""" | ||
|
||
import ctf | ||
import numpy | ||
import opt_einsum | ||
|
||
|
||
def __getattr__(name): | ||
"""Get the attribute from CTF.""" | ||
return getattr(ctf, name) | ||
|
||
|
||
class FakeLinalg: | ||
"""Fake linalg module for CTF.""" | ||
|
||
def __getattr__(self, name): | ||
"""Get the attribute from CTF's linalg module.""" | ||
return getattr(ctf.linalg, name) | ||
|
||
def eigh(self, a): # noqa: D102 | ||
# TODO Need to determine if SCALAPACK is available | ||
w, v = numpy.linalg.eigh(a.to_nparray()) | ||
w = ctf.astensor(w) | ||
v = ctf.astensor(v) | ||
return w, v | ||
|
||
def norm(self, a, ord=None): # noqa: D102 | ||
return ctf.norm(a, ord=ord) | ||
|
||
|
||
linalg = FakeLinalg() | ||
|
||
|
||
bool_ = numpy.bool_ | ||
inf = numpy.inf | ||
asarray = ctf.astensor | ||
|
||
|
||
_array = ctf.array | ||
|
||
|
||
def array(obj, **kwargs): # noqa: D103 | ||
if isinstance(obj, ctf.tensor): | ||
return obj | ||
return _array(numpy.asarray(obj), **kwargs) | ||
|
||
|
||
def astype(obj, dtype): # noqa: D103 | ||
return obj.astype(dtype) | ||
|
||
|
||
def zeros_like(obj): # noqa: D103 | ||
return ctf.zeros(obj.shape).astype(obj.dtype) | ||
|
||
|
||
def ones_like(obj): # noqa: D103 | ||
return ctf.ones(obj.shape).astype(obj.dtype) | ||
|
||
|
||
def arange(start, stop=None, step=1, dtype=None): # noqa: D103 | ||
if stop is None: | ||
stop = start | ||
start = 0 | ||
return ctf.arange(start, stop, step=step, dtype=dtype) | ||
|
||
|
||
def argmin(obj): # noqa: D103 | ||
return ctf.to_nparray(obj).argmin() | ||
|
||
|
||
def argmax(obj): # noqa: D103 | ||
return ctf.to_nparray(obj).argmax() | ||
|
||
|
||
def bitwise_and(a, b): # noqa: D103 | ||
return a * b | ||
|
||
|
||
def bitwise_not(a): # noqa: D103 | ||
return ones_like(a) - a | ||
|
||
|
||
def concatenate(arrays, axis=None): # noqa: D103 | ||
if axis is None: | ||
axis = 0 | ||
if axis < 0: | ||
axis += arrays[0].ndim | ||
shape = list(arrays[0].shape) | ||
for arr in arrays[1:]: | ||
for i, (a, b) in enumerate(zip(shape, arr.shape)): | ||
if i == axis: | ||
shape[i] += b | ||
elif a != b: | ||
raise ValueError("All arrays must have the same shape") | ||
|
||
result = ctf.zeros(shape, dtype=arrays[0].dtype) | ||
start = 0 | ||
for arr in arrays: | ||
end = start + arr.shape[axis] | ||
slices = [slice(None)] * result.ndim | ||
slices[axis] = slice(start, end) | ||
result[tuple(slices)] = arr | ||
start = end | ||
|
||
return result | ||
|
||
|
||
def _block_recursive(arrays, max_depth, depth=0): # noqa: D103 | ||
if depth < max_depth: | ||
arrs = [_block_recursive(arr, max_depth, depth + 1) for arr in arrays] | ||
return concatenate(arrs, axis=-(max_depth - depth)) | ||
else: | ||
return arrays | ||
|
||
|
||
def block(arrays): # noqa: D103 | ||
def _get_max_depth(arrays): | ||
if isinstance(arrays, list): | ||
return 1 + max([_get_max_depth(arr) for arr in arrays]) | ||
return 0 | ||
|
||
return _block_recursive(arrays, _get_max_depth(arrays)) | ||
|
||
|
||
def einsum(*args, optimize=True, **kwargs): | ||
"""Evaluate an einsum expression.""" | ||
# FIXME This shouldn't be called, except via `util.einsum`, which should have already | ||
# optimised the expression. We should check if this contraction has more than | ||
# two tensors and if so, raise an error. | ||
return ctf.einsum(*args, **kwargs) | ||
|
||
|
||
def einsum_path(*args, **kwargs): | ||
"""Evaluate the lowest cost contraction order for an einsum expression.""" | ||
kwargs = dict(kwargs) | ||
if kwargs.get("optimize", True) is True: | ||
kwargs["optimize"] = "optimal" | ||
return opt_einsum.contract_path(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# type: ignore | ||
"""CuPy backend.""" | ||
|
||
import cupy | ||
import opt_einsum | ||
|
||
|
||
def __getattr__(name): | ||
"""Get the attribute from CuPy.""" | ||
return getattr(cupy, name) | ||
|
||
|
||
def einsum_path(*args, **kwargs): | ||
"""Evaluate the lowest cost contraction order for an einsum expression.""" | ||
kwargs = dict(kwargs) | ||
if kwargs.get("optimize", True) is True: | ||
kwargs["optimize"] = "optimal" | ||
return opt_einsum.contract_path(*args, **kwargs) |
Oops, something went wrong.