Skip to content

Commit

Permalink
switch to jax.random for failing test_trimmed_mean_agrees_with_scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Jun 6, 2024
1 parent 55c8c12 commit 6bf517f
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions diffmah/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from jax import jit as jax_jit
from jax import numpy as jax_np
from jax import random as jran
from jax import value_and_grad

from ..utils import (
Expand Down Expand Up @@ -126,9 +127,11 @@ def test_get_cholesky_from_params1():
def test_trimmed_mean_agrees_with_scipy():
from scipy.stats.mstats import trimmed_mean as trimmed_mean_scipy

ran_key = jran.key(0)
ptest = 0.1, 0.2, 0.3
for p in ptest:
x = np.random.normal(loc=0, scale=1, size=20_000)
ran_key, test_key = jran.split(ran_key, 2)
x = jran.normal(test_key, shape=(20_000,))
mu_p10 = trimmed_mean(x, p)
mu_p10_scipy = trimmed_mean_scipy(x, p)
assert np.allclose(mu_p10, mu_p10_scipy, rtol=0.01)
Expand All @@ -138,9 +141,11 @@ def test_trimmed_mean_and_variance_agrees_with_scipy():
from scipy.stats.mstats import trimmed_mean as trimmed_mean_scipy
from scipy.stats.mstats import trimmed_var as trimmed_var_scipy

ran_key = jran.key(0)
ptest = 0.01, 0.1, 0.2, 0.3
for p in ptest:
x = np.random.normal(loc=0, scale=1, size=20_000)
ran_key, test_key = jran.split(ran_key, 2)
x = jran.normal(test_key, shape=(20_000,))
mu_p10, var_p10 = trimmed_mean_and_variance(x, p)
mu_p10_scipy = trimmed_mean_scipy(x, p)
var_p10_scipy = trimmed_var_scipy(x, p)
Expand All @@ -149,7 +154,8 @@ def test_trimmed_mean_and_variance_agrees_with_scipy():


def test_trimmed_mean_and_variance_consistency():
x = np.random.normal(loc=0, scale=1, size=20_000)
ran_key = jran.key(0)
x = jran.normal(ran_key, shape=(20_000,))
mu, var = trimmed_mean_and_variance(x, 0.1)
mu2 = trimmed_mean(x, 0.1)
assert np.allclose(mu, mu2, rtol=1e-4)

0 comments on commit 6bf517f

Please sign in to comment.