Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[6] Setup CI test pipeline #7

Merged
merged 8 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Build and Test

on:
push:
branches:
- main
pull_request:
branches:
- main
- develop

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies and Build package
run: |
python -m pip install --upgrade pip setuptools wheel
pip install .["git_ci"]

- name: Run tests
run: |
pytest -n=auto -x
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ include-package-data = true
package-data = {"*" = ["AUTHORS.txt", "VERSION.txt"]}

[tool.setuptools.dynamic]
version = {file = "robust_hermite_ft/VERSION.txt"}
version = {file = "src/robust_hermite_ft/VERSION.txt"}
readme = {file = ["README.rst"]}
dependencies = {file = "requirements/base.txt"}
optional-dependencies = {fast = {file = "requirements/fast.txt"}, dev = {file = "requirements/dev.txt"}, examples = {file = "requirements/examples.txt"}}
optional-dependencies = {fast = {file = "requirements/fast.txt"}, dev = {file = "requirements/dev.txt"}, examples = {file = "requirements/examples.txt"}, git_ci = {file = "requirements/git_ci.txt"}}

[tool.isort]
profile = "black"
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
norecursedirs = tests/reference_files
9 changes: 9 additions & 0 deletions requirements/git_ci.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
black
coverage
cython>=3.0.10
cython-lint>=0.16.0
numba>=0.55.0
pytest
pytest-cov
pytest-xdist
ruff
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import Cython.Compiler.Options
import numpy as np
from Cython.Build import cythonize
from setuptools import Extension, setup
from setuptools import Extension, find_packages, setup

# === Constants ===

SOURCES = [
"robust_hermite_ft/hermite_functions/_c_hermite.pyx",
"src/robust_hermite_ft/hermite_functions/_c_hermite.pyx",
]

# === Setup ===
Expand All @@ -41,8 +41,10 @@
]

setup(
package_dir={"": "src"},
packages=find_packages("src"),
ext_modules=cythonize(CY_MODULES, nthreads=1, annotate=True),
package_data={"robust_hermite_ft": ["*.pxd"]}, # include pxd files
include_package_data=False, # ignore other files
include_package_data=True,
zip_safe=False,
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
271 changes: 141 additions & 130 deletions tests/reference_files/generate_hermfunc_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,9 @@
import json
import os
from dataclasses import asdict, dataclass, field
from functools import partial
from multiprocessing import Pool
from time import perf_counter
from typing import Dict, Tuple
from typing import Dict, List, Tuple

import numpy as np
from sympy import Symbol as sp_Symbol
from sympy import exp as sp_exp
from sympy import pi as sp_pi
from sympy import sqrt as sp_sqrt
from sympy import symbols as sp_symbols
from tqdm import tqdm

# === Constants ===

Expand Down Expand Up @@ -61,7 +52,7 @@ class HermiteFunctionsParameters:

n: int
alpha: float
ns_for_single_function: list[int] = field(default_factory=list)
ns_for_single_function: List[int] = field(default_factory=list)


@dataclass
Expand All @@ -86,132 +77,152 @@ class ReferenceHermiteFunctionsMetadata:
x_values: np.ndarray


# === Functions ===
# === Main code ===

if __name__ == "__main__":

def _eval_sym_hermite_worker(
row_index: int,
x: np.ndarray,
x_sym: sp_Symbol,
n: int,
alpha: float,
expressions: np.ndarray,
num_digits: int,
) -> Tuple[int, np.ndarray]:
"""
Worker function to evaluate the Hermite functions at the given points ``x``.

"""

# the Hermite functions are evaluated at the given points
hermite_function_values = np.empty(shape=n + 1, dtype=np.float64)

# the Hermite functions are evaluated using the recurrence relation
for iter_j in range(0, n + 1):
# the expression for the Hermite function is evaluated
hermite_expression = expressions[iter_j]
hermite_function_values[iter_j] = hermite_expression.subs(
x_sym, x[row_index] / alpha
).evalf(n=num_digits)

return row_index, hermite_function_values


def _eval_sym_dilated_hermite_function_basis(
x: np.ndarray,
n: int,
alpha: float,
num_digits: int = 16,
) -> np.ndarray:
"""
Evaluates the first ``n + 1`` dilated Hermite functions at the given points ``x``.
They are defined as

.. image:: docs/hermite_functions/equations/DilatedHermiteFunctions.png

Parameters
----------
x : :class:`np.ndarray` of shape (m,)
The points at which the Hermite functions are evaluated.
n : :class:`int`
The order of the Hermite functions.
alpha : :class:`float`
The scaling factor of the independent variable ``x``.
num_digits : :class:`int`, default=16
The number of digits used in the symbolic evaluation of the Hermite functions.
For orders ``n >= 50`` and high ``x / alpha``-values, the symbolic evaluation
might be inaccurate. In this case, going to quadruple precision
(``n_digits~=32``) or higher might be necessary.

Returns
-------
hermite_function_basis : :class:`np.ndarray` of shape (m, n + 1)
The values of the first ``n + 1`` dilated Hermite functions evaluated at the
points ``x``.

"""

# the Hermite functions are evaluated using their recurrence relation given by
# h_{n+1}(x) = sqrt(2 / (n + 1)) * x * h_{n}(x) - sqrt(n / (n + 1)) * h_{n-1}(x)
# with the initial conditions h_{-1}(x) = 0 and
# h_{0}(x) = pi**(-1/4) * exp(-x**2 / 2)
x_sym = sp_symbols("x")
hermite_expressions = np.empty(shape=(n + 1), dtype=object)

# the first two Hermite function expressions are defined with the involved Gaussian
# function not multiplied in yet to avoid the build-up of large expressions
h_i_minus_1 = 0
h_i = sp_exp(-(x_sym**2) / 2) / sp_sqrt(sp_sqrt(sp_pi)) # type: ignore
hermite_expressions[0] = h_i

# the Hermite functions are evaluated using the recurrence relation
for iter_j in tqdm(range(0, n), desc="Generating Hermite expressions", leave=False):
h_i_plus_1 = (
sp_sqrt(2 / (iter_j + 1)) * x_sym * h_i
- sp_sqrt(iter_j / (iter_j + 1)) * h_i_minus_1 # type: ignore
)
h_i_minus_1, h_i = h_i, h_i_plus_1
hermite_expressions[iter_j + 1] = h_i

# the Hermite functions are evaluated at the given points
hermite_function_basis = np.empty(shape=(x.size, n + 1), dtype=np.float64)

# the evaluation is done in parallel to speed up the process but a progress bar is
# used to keep track of the progress
with Pool() as pool:
worker = partial(
_eval_sym_hermite_worker,
x=x,
x_sym=x_sym,
n=n,
alpha=alpha,
expressions=hermite_expressions,
num_digits=num_digits,
)
results = list(
tqdm(
pool.imap(worker, range(0, x.size)),
total=x.size,
desc="Evaluating Hermite functions",
leave=False,
# === Imports ===

from functools import partial
from multiprocessing import Pool
from time import perf_counter

from sympy import Symbol as sp_Symbol
from sympy import exp as sp_exp
from sympy import pi as sp_pi
from sympy import sqrt as sp_sqrt
from sympy import symbols as sp_symbols
from tqdm import tqdm

# === Functions ===

def _eval_sym_hermite_worker(
row_index: int,
x: np.ndarray,
x_sym: sp_Symbol,
n: int,
alpha: float,
expressions: np.ndarray,
num_digits: int,
) -> Tuple[int, np.ndarray]:
"""
Worker function to evaluate the Hermite functions at the given points ``x``.

"""

# the Hermite functions are evaluated at the given points
hermite_function_values = np.empty(shape=n + 1, dtype=np.float64)

# the Hermite functions are evaluated using the recurrence relation
for iter_j in range(0, n + 1):
# the expression for the Hermite function is evaluated
hermite_expression = expressions[iter_j]
hermite_function_values[iter_j] = hermite_expression.subs(
x_sym, x[row_index] / alpha
).evalf(n=num_digits)

return row_index, hermite_function_values

def _eval_sym_dilated_hermite_function_basis(
x: np.ndarray,
n: int,
alpha: float,
num_digits: int = 16,
) -> np.ndarray:
"""
Evaluates the first ``n + 1`` dilated Hermite functions at the given points
``x``.
They are defined as

.. image:: docs/hermite_functions/equations/DilatedHermiteFunctions.png

Parameters
----------
x : :class:`np.ndarray` of shape (m,)
The points at which the Hermite functions are evaluated.
n : :class:`int`
The order of the Hermite functions.
alpha : :class:`float`
The scaling factor of the independent variable ``x``.
num_digits : :class:`int`, default=16
The number of digits used in the symbolic evaluation of the Hermite
functions.
For orders ``n >= 50`` and high ``x / alpha``-values, the symbolic
evaluation might be inaccurate. In this case, going to quadruple precision
(``n_digits~=32``) or higher might be necessary.

Returns
-------
hermite_function_basis : :class:`np.ndarray` of shape (m, n + 1)
The values of the first ``n + 1`` dilated Hermite functions evaluated at the
points ``x``.

"""

# the Hermite functions are evaluated using their recurrence relation given by
# h_{n+1}(x) = sqrt(2 / (n + 1)) * x * h_{n}(x) - sqrt(n / (n + 1)) * h_{n-1}(x)
# with the initial conditions h_{-1}(x) = 0 and
# h_{0}(x) = pi**(-1/4) * exp(-x**2 / 2)
x_sym = sp_symbols("x")
hermite_expressions = np.empty(shape=(n + 1), dtype=object)

# the first two Hermite function expressions are defined with the involved
# Gaussian function not multiplied in yet to avoid the build-up of large
# expressions
h_i_minus_1 = 0
h_i = sp_exp(-(x_sym**2) / 2) / sp_sqrt(sp_sqrt(sp_pi)) # type: ignore
hermite_expressions[0] = h_i

# the Hermite functions are evaluated using the recurrence relation
for iter_j in tqdm(
range(0, n),
desc="Generating Hermite expressions",
leave=False,
):
h_i_plus_1 = (
sp_sqrt(2 / (iter_j + 1)) * x_sym * h_i
- sp_sqrt(iter_j / (iter_j + 1)) * h_i_minus_1 # type: ignore
)
h_i_minus_1, h_i = h_i, h_i_plus_1
hermite_expressions[iter_j + 1] = h_i

# the Hermite functions are evaluated at the given points
hermite_function_basis = np.empty(shape=(x.size, n + 1), dtype=np.float64)

# the evaluation is done in parallel to speed up the process but a progress bar
# is used to keep track of the progress
with Pool() as pool:
worker = partial(
_eval_sym_hermite_worker,
x=x,
x_sym=x_sym,
n=n,
alpha=alpha,
expressions=hermite_expressions,
num_digits=num_digits,
)
results = list(
tqdm(
pool.imap(worker, range(0, x.size)),
total=x.size,
desc="Evaluating Hermite functions",
leave=False,
)
)
)

# the results are stored in the matrix
for row_idx, row_values in results:
hermite_function_basis[row_idx, ::] = row_values

return hermite_function_basis / np.sqrt(alpha)

# the results are stored in the matrix
for row_idx, row_values in results:
hermite_function_basis[row_idx, ::] = row_values

# === Main ===
return hermite_function_basis / np.sqrt(alpha)

# this part generates NumPy binary files for the first 250 dilated Hermite functions
# with different scaling factors evaluated at high precision for a series of 501 points
# in the range [-45, 45]
# NOTE: it is important that the number of points is odd to have a point at exactly 0
# === Test file generation ===

if __name__ == "__main__":
# this part generates NumPy binary files for the first 250 dilated Hermite functions
# with different scaling factors evaluated at high precision for a series of 501
# points in the range [-45, 45]
# NOTE: it is important that the number of points is odd to have a point at
# exactly 0

# --- Setup ---

Expand Down
Loading