diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py b/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py index 054cd81..be58fa2 100644 --- a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py +++ b/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py @@ -41,7 +41,7 @@ def _pred_c0_kern(params, t_obs, t_peak): c0_ytp, c0_ylo, clip_c0, clip_c1 = params pred_c0 = _sig_slope(t_obs, XTP, c0_ytp, GLOBAL_X0, GLOBAL_K, c0_ylo, 0.0) clip = clip_c0 + clip_c1 * t_peak - pred_c0 = jnp.clip(pred_c0, a_min=clip) + pred_c0 = jnp.clip(pred_c0, min=clip) return pred_c0 diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py b/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py index 8f43c7e..0059199 100644 --- a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py +++ b/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py @@ -46,7 +46,7 @@ def _pred_c1_kern(params, t_obs, t_peak): pred_c1 = _sig_slope(t_obs, XTP, c1_ytp, GLOBAL_X0, GLOBAL_K, c1_ylo, 0.0) clip = _sigmoid(t_peak, c1_clip_x0, CLIP_TP_K, c1_clip_ylo, c1_clip_yhi) - pred_c1 = jnp.clip(pred_c1, a_min=clip) + pred_c1 = jnp.clip(pred_c1, min=clip) return pred_c1 diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py b/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py index 0c0e787..b09c0f1 100644 --- a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py +++ b/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py @@ -2,6 +2,7 @@ """ import numpy as np +from jax import random as jran from .. import logm0_pop as m0pop @@ -66,3 +67,20 @@ def test_default_params_are_in_bounds(): val = getattr(m0pop.DEFAULT_LOGM0POP_PARAMS, key) bound = getattr(m0pop.LGM0POP_BOUNDS, key) assert bound[0] < val < bound[1] + + +def test_pred_logm0_kern(): + ran_key = jran.key(0) + n_tests = 1_000 + for __ in range(n_tests): + lgm_key, t_obs_key, t_peak_key, ran_key = jran.split(ran_key, 4) + lgm_obs = jran.uniform(lgm_key, minval=5, maxval=16, shape=()) + t_obs = jran.uniform(t_obs_key, minval=1, maxval=20, shape=()) + t_peak = jran.uniform(t_peak_key, minval=1.5, maxval=20, shape=()) + lgm0 = m0pop._pred_logm0_kern( + m0pop.DEFAULT_LOGM0POP_PARAMS, lgm_obs, t_obs, t_peak + ) + assert lgm0.shape == () + assert np.isfinite(lgm0) + assert lgm0 > 0 + assert lgm0 < 20