From f84bb93ba996bede1f3da81447a46ff7b31f9193 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Fri, 9 Feb 2024 15:30:52 -0600 Subject: [PATCH] jax.lax.logistic-->jax.nn.sigmoid --- diffmah/individual_halo_assembly.py | 9 ++------- diffmah/rockstar_pdf_model.py | 9 ++------- diffmah/tests/test_utils.py | 9 +++++---- diffmah/tng_pdf_model.py | 9 ++------- diffmah/utils.py | 11 +++++++---- 5 files changed, 18 insertions(+), 29 deletions(-) diff --git a/diffmah/individual_halo_assembly.py b/diffmah/individual_halo_assembly.py index 5405785..2687091 100644 --- a/diffmah/individual_halo_assembly.py +++ b/diffmah/individual_halo_assembly.py @@ -1,4 +1,5 @@ """Model for individual halo mass assembly based on a power-law with rolling index.""" + from jax import grad from jax import jit as jjit from jax import lax @@ -6,7 +7,7 @@ from jax import vmap from .defaults import LGT0, MAH_K -from .utils import get_1d_arrays +from .utils import _sigmoid, get_1d_arrays @jjit @@ -137,12 +138,6 @@ def _softplus(x): return jnp.log(1 + lax.exp(x)) -@jjit -def _sigmoid(x, logtc, k, ymin, ymax): - height_diff = ymax - ymin - return ymin + height_diff / (1.0 + lax.exp(-k * (x - logtc))) - - @jjit def _get_early_late(ue, ul): late = _softplus(ul) diff --git a/diffmah/rockstar_pdf_model.py b/diffmah/rockstar_pdf_model.py index 3054c51..15eb849 100644 --- a/diffmah/rockstar_pdf_model.py +++ b/diffmah/rockstar_pdf_model.py @@ -1,4 +1,5 @@ """Model of halo population assembly calibrated to Rockstar halos.""" + from collections import OrderedDict from jax import jit as jjit @@ -7,7 +8,7 @@ from jax import vmap from .defaults import MAH_K -from .utils import get_cholesky_from_params +from .utils import _sigmoid, get_cholesky_from_params TODAY = 13.8 LGT0 = jnp.log10(TODAY) @@ -56,12 +57,6 @@ ) -@jjit -def _sigmoid(x, logtc, k, ymin, ymax): - height_diff = ymax - ymin - return ymin + height_diff / (1.0 + lax.exp(-k * (x - logtc))) - - def _get_cov_scalar( log10_lge_lge, log10_lgl_lgl, diff --git a/diffmah/tests/test_utils.py b/diffmah/tests/test_utils.py index ce5e905..13fbf51 100644 --- a/diffmah/tests/test_utils.py +++ b/diffmah/tests/test_utils.py @@ -1,15 +1,16 @@ """ """ + import numpy as np from jax import jit as jax_jit from jax import numpy as jax_np from jax import value_and_grad from ..utils import ( + _inverse_sigmoid, + _sigmoid, get_cholesky_from_params, jax_adam_wrapper, - jax_inverse_sigmoid, - jax_sigmoid, ) @@ -17,8 +18,8 @@ def test_inverse_sigmoid_actually_inverts(): """""" x0, k, ylo, yhi = 0, 5, 1, 0 xarr = np.linspace(-1, 1, 100) - yarr = np.array(jax_sigmoid(xarr, x0, k, ylo, yhi)) - xarr2 = np.array(jax_inverse_sigmoid(yarr, x0, k, ylo, yhi)) + yarr = np.array(_sigmoid(xarr, x0, k, ylo, yhi)) + xarr2 = np.array(_inverse_sigmoid(yarr, x0, k, ylo, yhi)) assert np.allclose(xarr, xarr2, rtol=1e-3) diff --git a/diffmah/tng_pdf_model.py b/diffmah/tng_pdf_model.py index 016a0cb..50904f3 100644 --- a/diffmah/tng_pdf_model.py +++ b/diffmah/tng_pdf_model.py @@ -1,5 +1,6 @@ """ """ + from collections import OrderedDict from jax import jit as jjit @@ -8,7 +9,7 @@ from jax import vmap from .defaults import MAH_K -from .utils import get_cholesky_from_params +from .utils import _sigmoid, get_cholesky_from_params TODAY = 13.8 LGT0 = jnp.log10(TODAY) @@ -57,12 +58,6 @@ ) -@jjit -def _sigmoid(x, logtc, k, ymin, ymax): - height_diff = ymax - ymin - return ymin + height_diff / (1.0 + lax.exp(-k * (x - logtc))) - - def _get_cov_scalar( log10_lge_lge, log10_lgl_lgl, diff --git a/diffmah/utils.py b/diffmah/utils.py index 5b39b6b..dc7627f 100644 --- a/diffmah/utils.py +++ b/diffmah/utils.py @@ -1,7 +1,8 @@ """Utility functions used throughout the package.""" + import numpy as np from jax import jit as jjit -from jax import lax +from jax import nn from jax import numpy as jnp from jax.example_libraries import optimizers as jax_opt @@ -25,7 +26,8 @@ def get_1d_arrays(*args, jax_arrays=False): return result -def jax_sigmoid(x, x0, k, ylo, yhi): +@jjit +def _sigmoid(x, x0, k, ylo, yhi): """Sigmoid function implemented w/ `jax.numpy.exp`. Parameters @@ -45,10 +47,11 @@ def jax_sigmoid(x, x0, k, ylo, yhi): ------- sigmoid : scalar or array-like, same shape as input """ - return ylo + (yhi - ylo) / (1 + lax.exp(-k * (x - x0))) + return ylo + (yhi - ylo) * nn.sigmoid(k * (x - x0)) -def jax_inverse_sigmoid(y, x0, k, ylo, yhi): +@jjit +def _inverse_sigmoid(y, x0, k, ylo, yhi): """Sigmoid function implemented w/ `jax.numpy.exp`. Parameters