diff --git a/diffmah/tests/test_utils.py b/diffmah/tests/test_utils.py index 9805212..0b51412 100644 --- a/diffmah/tests/test_utils.py +++ b/diffmah/tests/test_utils.py @@ -4,6 +4,7 @@ import numpy as np from jax import jit as jax_jit from jax import numpy as jax_np +from jax import random as jran from jax import value_and_grad from ..utils import ( @@ -126,9 +127,11 @@ def test_get_cholesky_from_params1(): def test_trimmed_mean_agrees_with_scipy(): from scipy.stats.mstats import trimmed_mean as trimmed_mean_scipy + ran_key = jran.key(0) ptest = 0.1, 0.2, 0.3 for p in ptest: - x = np.random.normal(loc=0, scale=1, size=20_000) + ran_key, test_key = jran.split(ran_key, 2) + x = jran.normal(test_key, shape=(20_000,)) mu_p10 = trimmed_mean(x, p) mu_p10_scipy = trimmed_mean_scipy(x, p) assert np.allclose(mu_p10, mu_p10_scipy, rtol=0.01) @@ -138,9 +141,11 @@ def test_trimmed_mean_and_variance_agrees_with_scipy(): from scipy.stats.mstats import trimmed_mean as trimmed_mean_scipy from scipy.stats.mstats import trimmed_var as trimmed_var_scipy + ran_key = jran.key(0) ptest = 0.01, 0.1, 0.2, 0.3 for p in ptest: - x = np.random.normal(loc=0, scale=1, size=20_000) + ran_key, test_key = jran.split(ran_key, 2) + x = jran.normal(test_key, shape=(20_000,)) mu_p10, var_p10 = trimmed_mean_and_variance(x, p) mu_p10_scipy = trimmed_mean_scipy(x, p) var_p10_scipy = trimmed_var_scipy(x, p) @@ -149,7 +154,8 @@ def test_trimmed_mean_and_variance_agrees_with_scipy(): def test_trimmed_mean_and_variance_consistency(): - x = np.random.normal(loc=0, scale=1, size=20_000) + ran_key = jran.key(0) + x = jran.normal(ran_key, shape=(20_000,)) mu, var = trimmed_mean_and_variance(x, 0.1) mu2 = trimmed_mean(x, 0.1) assert np.allclose(mu, mu2, rtol=1e-4)