Merge pull request #128 from ArgonneCPAC/kde_loss_dev_1d
Kde loss dev 1d
aphearin committed Jul 18, 2024
2 parents 8b9438c + a5c6ce3 commit 2efee6a
Showing 3 changed files with 344 additions and 0 deletions.
146 changes: 146 additions & 0 deletions diffmah/diffmahpop_kernels/kdescent_testing/kde1d_wrappers.py
@@ -0,0 +1,146 @@
"""
"""

from jax import random as jran
from jax import value_and_grad, vmap

try:
import kdescent
except ImportError:
pass
from jax import jit as jjit
from jax import numpy as jnp

from .. import diffmahpop_params as dpp
from .. import mc_diffmahpop_kernels as mdk

N_T_PER_BIN = 5


@jjit
def mc_diffmah_preds(diffmahpop_u_params, pred_data):
diffmahpop_params = dpp.get_diffmahpop_params_from_u_params(diffmahpop_u_params)
tarr, lgm_obs, t_obs, ran_key, lgt0 = pred_data
_res = mdk._mc_diffmah_halo_sample(
diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0
)
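    # Unpack the MC halo sample: ftpt0 weights the t_peak = t_0 population,
    # and the two log_mah arrays hold the MAHs of the two populations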
ftpt0 = _res[3]
log_mah_tpt0 = _res[6]
log_mah_tp = _res[8]
return log_mah_tpt0, log_mah_tp, ftpt0


@jjit
def single_sample_kde_loss_self_fit(
diffmahpop_u_params,
tarr,
lgm_obs,
t_obs,
ran_key,
lgt0,
log_mahs_target,
weights_target,
):
s = (-1, 1)
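    # One KDE comparison object per time in the grid (N_T_PER_BIN columns)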
kcalc0 = kdescent.KCalc(log_mahs_target[:, 0].reshape(s), weights_target)
kcalc1 = kdescent.KCalc(log_mahs_target[:, 1].reshape(s), weights_target)
kcalc2 = kdescent.KCalc(log_mahs_target[:, 2].reshape(s), weights_target)
kcalc3 = kdescent.KCalc(log_mahs_target[:, 3].reshape(s), weights_target)
kcalc4 = kdescent.KCalc(log_mahs_target[:, 4].reshape(s), weights_target)

ran_key, pred_key = jran.split(ran_key, 2)
pred_data = tarr, lgm_obs, t_obs, pred_key, lgt0
_res = mc_diffmah_preds(diffmahpop_u_params, pred_data)
log_mah_tpt0, log_mah_tp, ftpt0 = _res

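    # Stack the two MC populations, weighted by ftpt0 and 1 - ftpt0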
log_mahs_pred = jnp.concatenate((log_mah_tpt0, log_mah_tp))
weights_pred = jnp.concatenate((ftpt0, 1 - ftpt0))

kcalc_keys = jran.split(ran_key, N_T_PER_BIN)

model_counts0, truth_counts0 = kcalc0.compare_kde_counts(
kcalc_keys[0], log_mahs_pred[:, 0].reshape(s), weights_pred
)
model_counts1, truth_counts1 = kcalc1.compare_kde_counts(
kcalc_keys[1], log_mahs_pred[:, 1].reshape(s), weights_pred
)
model_counts2, truth_counts2 = kcalc2.compare_kde_counts(
kcalc_keys[2], log_mahs_pred[:, 2].reshape(s), weights_pred
)
model_counts3, truth_counts3 = kcalc3.compare_kde_counts(
kcalc_keys[3], log_mahs_pred[:, 3].reshape(s), weights_pred
)
model_counts4, truth_counts4 = kcalc4.compare_kde_counts(
kcalc_keys[4], log_mahs_pred[:, 4].reshape(s), weights_pred
)

diff0 = model_counts0 - truth_counts0
diff1 = model_counts1 - truth_counts1
diff2 = model_counts2 - truth_counts2
diff3 = model_counts3 - truth_counts3
diff4 = model_counts4 - truth_counts4

loss0 = jnp.mean(diff0**2)
loss1 = jnp.mean(diff1**2)
loss2 = jnp.mean(diff2**2)
loss3 = jnp.mean(diff3**2)
loss4 = jnp.mean(diff4**2)

loss = loss0 + loss1 + loss2 + loss3 + loss4
return loss


single_sample_kde_loss_and_grad_self_fit = jjit(
value_and_grad(single_sample_kde_loss_self_fit)
)


@jjit
def get_single_sample_self_fit_target_data(
u_params, tarr, lgm_obs, t_obs, ran_key, lgt0
):
pred_data = tarr, lgm_obs, t_obs, ran_key, lgt0
_res = mc_diffmah_preds(u_params, pred_data)
log_mah_tpt0, log_mah_tp, ftpt0 = _res
log_mahs_target = jnp.concatenate((log_mah_tpt0, log_mah_tp))
weights_target = jnp.concatenate((ftpt0, 1 - ftpt0))
return log_mahs_target, weights_target


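# vmap over per-sample (tarr, lgm_obs, t_obs, ran_key); u_params and lgt0 are shared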
_A = (None, 0, 0, 0, 0, None)
get_multisample_self_fit_target_data = jjit(
vmap(get_single_sample_self_fit_target_data, in_axes=_A)
)

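# vmap the single-sample loss over per-sample inputs and per-sample target data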
_L = (None, 0, 0, 0, 0, None, 0, 0)
_multisample_kde_loss_self_fit = jjit(vmap(single_sample_kde_loss_self_fit, in_axes=_L))


@jjit
def multisample_kde_loss_self_fit(
diffmahpop_u_params,
tarr_matrix,
lgmobsarr,
tobsarr,
ran_keys,
lgt0,
log_mahs_targets,
weights_targets,
):
losses = _multisample_kde_loss_self_fit(
diffmahpop_u_params,
tarr_matrix,
lgmobsarr,
tobsarr,
ran_keys,
lgt0,
log_mahs_targets,
weights_targets,
)
loss = jnp.mean(losses)
return loss


multisample_kde_loss_and_grad_self_fit = jjit(
value_and_grad(multisample_kde_loss_self_fit)
)
@@ -0,0 +1,192 @@
"""
"""

import numpy as np
import pytest
from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran

from ... import diffmahpop_params as dpp
from .. import kde1d_wrappers as k1w

try:
import kdescent # noqa

HAS_KDESCENT = True
except ImportError:
HAS_KDESCENT = False

T_MIN_FIT = 0.5


def test_mc_diffmah_preds():
ran_key = jran.key(0)
t_0 = 13.8
lgt0 = np.log10(t_0)
n_times = 5

n_tests = 20
for __ in range(n_tests):
ran_key, m_key, t_key = jran.split(ran_key, 3)
lgm_obs = jran.uniform(m_key, minval=9, maxval=16, shape=())
t_obs = jran.uniform(t_key, minval=3, maxval=t_0, shape=())
tarr = np.linspace(T_MIN_FIT, t_obs, n_times)
pred_data = tarr, lgm_obs, t_obs, ran_key, lgt0
_preds = k1w.mc_diffmah_preds(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, pred_data)
for _x in _preds:
assert np.all(np.isfinite(_x))
log_mah_tpt0, log_mah_tp, ftpt0 = _preds
assert np.all(log_mah_tpt0 < 20)
assert np.all(log_mah_tp < 20)
assert np.all(ftpt0 <= 1)
assert np.all(ftpt0 >= 0)


@pytest.mark.skipif("not HAS_KDESCENT")
def test_single_sample_kde_loss_self_fit():
"""Enforce that single-sample loss has finite grads"""
ran_key = jran.key(0)

t_0 = 13.8
lgt0 = np.log10(t_0)
DP = 0.5

n_tests = 100
for __ in range(n_tests):

        # Use random diffmahpop parameters to generate fiducial data
u_p_fid_key, ran_key = jran.split(ran_key, 2)
n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
_u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

ran_key, lgm_key, t_obs_key = jran.split(ran_key, 3)
lgm_obs = jran.uniform(lgm_key, minval=11, maxval=15, shape=())
t_obs = jran.uniform(t_obs_key, minval=4, maxval=t_0, shape=())

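        # Time grid with exactly N_T_PER_BIN points, matching the per-time KCalc objects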
tarr = np.linspace(T_MIN_FIT, t_obs, k1w.N_T_PER_BIN)

_res = k1w.get_single_sample_self_fit_target_data(
u_p_fid, tarr, lgm_obs, t_obs, ran_key, lgt0
)
log_mahs_target, weights_target = _res

# use default params as the initial guess
u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(
dpp.DEFAULT_DIFFMAHPOP_U_PARAMS
)

loss_data = tarr, lgm_obs, t_obs, ran_key, lgt0, log_mahs_target, weights_target
loss = k1w.single_sample_kde_loss_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(loss)), (lgm_obs, t_obs)
assert loss > 0

loss, grads = k1w.single_sample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(grads)), (lgm_obs, t_obs)


@pytest.mark.skipif("not HAS_KDESCENT")
def test_multisample_kde_loss_self_fit():
"""Enforce that multi-sample loss has finite grads"""
ran_key = jran.key(0)

    # Use random diffmahpop parameters to generate fiducial data
DP = 0.1
u_p_fid_key, ran_key = jran.split(ran_key, 2)
n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
_u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
num_samples = 5
t_0 = 13.8
lgt0 = np.log10(t_0)
lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
num_target_redshifts_per_t_obs = 10

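    # One time grid per (lgm_obs, t_obs) sample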
tarr_matrix = jnp.array(
[jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
)
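    # First num_samples keys generate the targets; the rest go to the loss evaluation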
_keys = jran.split(ran_key, num_samples * 2)
_res = k1w.get_multisample_self_fit_target_data(
u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
)
log_mahs_targets, weights_targets = _res
loss_data = (
tarr_matrix,
lgmobsarr,
tobsarr,
_keys[num_samples:],
lgt0,
log_mahs_targets,
weights_targets,
)
# use default params as the initial guess
u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
loss = k1w.multisample_kde_loss_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(loss))
assert loss > 0

loss, grads = k1w.multisample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(grads))


@pytest.mark.skipif("not HAS_KDESCENT")
def test_kdescent_adam_self_fit():
"""Enforce that kdescent.adam terminates without NaNs"""
ran_key = jran.key(0)

    # Use random diffmahpop parameters to generate fiducial data
DP = 0.5
u_p_fid_key, ran_key = jran.split(ran_key, 2)
n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
_u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
num_samples = 5
t_0 = 13.8
lgt0 = np.log10(t_0)
lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
num_target_redshifts_per_t_obs = 10

tarr_matrix = jnp.array(
[jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
)

@jjit
def kde_loss(u_p, randkey):
_keys = jran.split(randkey, num_samples * 2)
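        # Regenerate the target data from the fixed fiducial params at each call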
_res = k1w.get_multisample_self_fit_target_data(
u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
)
log_mahs_targets, weights_targets = _res
loss_data = (
tarr_matrix,
lgmobsarr,
tobsarr,
_keys[num_samples:],
lgt0,
log_mahs_targets,
weights_targets,
)
return k1w.multisample_kde_loss_self_fit(u_p, *loss_data)

# use default params as the initial guess
u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)

adam_results = kdescent.adam(
kde_loss,
u_p_init,
nsteps=5,
learning_rate=0.1,
randkey=12345,
)
u_p_best = adam_results[-1]
assert np.all(np.isfinite(u_p_best))
@@ -20,6 +20,8 @@
T_MIN_FIT = 0.5


@pytest.mark.skip
@pytest.mark.xfail
@pytest.mark.skipif("not HAS_KDESCENT")
def test_single_sample_kde_loss_self_fit():
"""Enforce that single-sample loss has finite grads"""
@@ -67,6 +69,8 @@ def test_single_sample_kde_loss_self_fit():
assert np.all(np.isfinite(grads)), (lgm_obs, t_obs)


@pytest.mark.skip
@pytest.mark.xfail
@pytest.mark.skipif("not HAS_KDESCENT")
def test_multisample_kde_loss_self_fit():
"""Enforce that multi-sample loss has finite grads"""
@@ -115,6 +119,8 @@ def test_multisample_kde_loss_self_fit():
assert np.all(np.isfinite(grads))


@pytest.mark.skip
@pytest.mark.xfail
@pytest.mark.skipif("not HAS_KDESCENT")
def test_kdescent_adam_self_fit():
"""Enforce that kdescent.adam terminates without NaNs"""
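
For reference, a minimal usage sketch (not part of this commit, and assuming kdescent is installed): it drives the multisample loss-and-grad wrapper defined above with a few plain gradient-descent steps on the unbounded parameters; the sample sizes, learning rate, and step count are arbitrary.

import numpy as np
from jax import numpy as jnp
from jax import random as jran

from diffmah.diffmahpop_kernels import diffmahpop_params as dpp
from diffmah.diffmahpop_kernels.kdescent_testing import kde1d_wrappers as k1w

ran_key = jran.key(0)
lgt0 = np.log10(13.8)
num_samples = 5

# Random (lgm_obs, t_obs) pairs and one time grid per pair
lgm_key, t_key, ran_key = jran.split(ran_key, 3)
lgmobsarr = jran.uniform(lgm_key, minval=11, maxval=15, shape=(num_samples,))
tobsarr = jran.uniform(t_key, minval=4, maxval=13, shape=(num_samples,))
tarr_matrix = jnp.array([jnp.linspace(0.5, t, 10) for t in tobsarr])

# Self-fit target data generated from the default unbounded parameters
_keys = jran.split(ran_key, num_samples * 2)
u_p = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS
targets = k1w.get_multisample_self_fit_target_data(
    u_p, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
)
loss_data = (tarr_matrix, lgmobsarr, tobsarr, _keys[num_samples:], lgt0, *targets)

# A few plain gradient-descent updates of the unbounded parameters
for _ in range(3):
    loss, grads = k1w.multisample_kde_loss_and_grad_self_fit(u_p, *loss_data)
    u_p = u_p._make([p - 0.05 * g for p, g in zip(u_p, grads)])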