Merge pull request #128 from ArgonneCPAC/kde_loss_dev_1d
Kde loss dev 1d
Showing 3 changed files with 344 additions and 0 deletions.
diffmah/diffmahpop_kernels/kdescent_testing/kde1d_wrappers.py (146 additions & 0 deletions)
@@ -0,0 +1,146 @@
""" | ||
""" | ||
|
||
from jax import random as jran | ||
from jax import value_and_grad, vmap | ||
|
||
try: | ||
import kdescent | ||
except ImportError: | ||
pass | ||
from jax import jit as jjit | ||
from jax import numpy as jnp | ||
|
||
from .. import diffmahpop_params as dpp | ||
from .. import mc_diffmahpop_kernels as mdk | ||
|
||
N_T_PER_BIN = 5 | ||


@jjit
def mc_diffmah_preds(diffmahpop_u_params, pred_data):
    """Monte Carlo sample of halo mass assembly histories at the input times."""
    diffmahpop_params = dpp.get_diffmahpop_params_from_u_params(diffmahpop_u_params)
    tarr, lgm_obs, t_obs, ran_key, lgt0 = pred_data
    _res = mdk._mc_diffmah_halo_sample(
        diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0
    )
    ftpt0 = _res[3]  # fraction of halos with t_peak equal to t_0
    log_mah_tpt0 = _res[6]
    log_mah_tp = _res[8]
    return log_mah_tpt0, log_mah_tp, ftpt0
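
# A minimal usage sketch (mirroring the tests in this PR); lgm_obs, t_obs,
# ran_key, and lgt0 are assumed to be defined as in those tests:
#
#     tarr = np.linspace(0.5, t_obs, N_T_PER_BIN)
#     pred_data = tarr, lgm_obs, t_obs, ran_key, lgt0
#     u_p = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS
#     log_mah_tpt0, log_mah_tp, ftpt0 = mc_diffmah_preds(u_p, pred_data)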


@jjit
def single_sample_kde_loss_self_fit(
    diffmahpop_u_params,
    tarr,
    lgm_obs,
    t_obs,
    ran_key,
    lgt0,
    log_mahs_target,
    weights_target,
):
    """KDE loss for a single (lgm_obs, t_obs) sample against target data."""
    # Build one KDE per time bin: each of the N_T_PER_BIN columns of the
    # target log_mahs gets its own kernel density estimate
    s = (-1, 1)
    kcalc0 = kdescent.KCalc(log_mahs_target[:, 0].reshape(s), weights_target)
    kcalc1 = kdescent.KCalc(log_mahs_target[:, 1].reshape(s), weights_target)
    kcalc2 = kdescent.KCalc(log_mahs_target[:, 2].reshape(s), weights_target)
    kcalc3 = kdescent.KCalc(log_mahs_target[:, 3].reshape(s), weights_target)
    kcalc4 = kdescent.KCalc(log_mahs_target[:, 4].reshape(s), weights_target)

    ran_key, pred_key = jran.split(ran_key, 2)
    pred_data = tarr, lgm_obs, t_obs, pred_key, lgt0
    _res = mc_diffmah_preds(diffmahpop_u_params, pred_data)
    log_mah_tpt0, log_mah_tp, ftpt0 = _res

    # Stack the two halo populations, weighting them by ftpt0 and 1 - ftpt0
    log_mahs_pred = jnp.concatenate((log_mah_tpt0, log_mah_tp))
    weights_pred = jnp.concatenate((ftpt0, 1 - ftpt0))

    kcalc_keys = jran.split(ran_key, N_T_PER_BIN)

    model_counts0, truth_counts0 = kcalc0.compare_kde_counts(
        kcalc_keys[0], log_mahs_pred[:, 0].reshape(s), weights_pred
    )
    model_counts1, truth_counts1 = kcalc1.compare_kde_counts(
        kcalc_keys[1], log_mahs_pred[:, 1].reshape(s), weights_pred
    )
    model_counts2, truth_counts2 = kcalc2.compare_kde_counts(
        kcalc_keys[2], log_mahs_pred[:, 2].reshape(s), weights_pred
    )
    model_counts3, truth_counts3 = kcalc3.compare_kde_counts(
        kcalc_keys[3], log_mahs_pred[:, 3].reshape(s), weights_pred
    )
    model_counts4, truth_counts4 = kcalc4.compare_kde_counts(
        kcalc_keys[4], log_mahs_pred[:, 4].reshape(s), weights_pred
    )

    # The loss is the sum over time bins of the mean squared difference
    # between model and target KDE counts
    diff0 = model_counts0 - truth_counts0
    diff1 = model_counts1 - truth_counts1
    diff2 = model_counts2 - truth_counts2
    diff3 = model_counts3 - truth_counts3
    diff4 = model_counts4 - truth_counts4

    loss0 = jnp.mean(diff0**2)
    loss1 = jnp.mean(diff1**2)
    loss2 = jnp.mean(diff2**2)
    loss3 = jnp.mean(diff3**2)
    loss4 = jnp.mean(diff4**2)

    loss = loss0 + loss1 + loss2 + loss3 + loss4
    return loss


single_sample_kde_loss_and_grad_self_fit = jjit(
    value_and_grad(single_sample_kde_loss_self_fit)
)


@jjit
def get_single_sample_self_fit_target_data(
    u_params, tarr, lgm_obs, t_obs, ran_key, lgt0
):
    """Generate target data by evaluating the model at the input u_params."""
    pred_data = tarr, lgm_obs, t_obs, ran_key, lgt0
    _res = mc_diffmah_preds(u_params, pred_data)
    log_mah_tpt0, log_mah_tp, ftpt0 = _res
    log_mahs_target = jnp.concatenate((log_mah_tpt0, log_mah_tp))
    weights_target = jnp.concatenate((ftpt0, 1 - ftpt0))
    return log_mahs_target, weights_target


_A = (None, 0, 0, 0, 0, None)
get_multisample_self_fit_target_data = jjit(
    vmap(get_single_sample_self_fit_target_data, in_axes=_A)
)

_L = (None, 0, 0, 0, 0, None, 0, 0)
_multisample_kde_loss_self_fit = jjit(vmap(single_sample_kde_loss_self_fit, in_axes=_L))
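
# In both in_axes tuples above, None broadcasts a single argument across all
# samples (the u_params and lgt0 inputs), while 0 maps over the leading axis
# of the per-sample arguments (tarr, lgm_obs, t_obs, ran_key, and, for the
# loss kernel, the target log_mahs and weights)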


@jjit
def multisample_kde_loss_self_fit(
    diffmahpop_u_params,
    tarr_matrix,
    lgmobsarr,
    tobsarr,
    ran_keys,
    lgt0,
    log_mahs_targets,
    weights_targets,
):
    losses = _multisample_kde_loss_self_fit(
        diffmahpop_u_params,
        tarr_matrix,
        lgmobsarr,
        tobsarr,
        ran_keys,
        lgt0,
        log_mahs_targets,
        weights_targets,
    )
    loss = jnp.mean(losses)
    return loss


multisample_kde_loss_and_grad_self_fit = jjit(
    value_and_grad(multisample_kde_loss_self_fit)
)
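
Taken together, the wrappers above support a simple self-fit workflow: generate
target data from one set of parameters, then evaluate the loss and its gradient
at another. The following sketch mirrors the tests in this PR; the values of
lgm_obs, t_obs, and t_0 are illustrative, and kdescent must be installed:

import numpy as np
from jax import random as jran
from diffmah.diffmahpop_kernels import diffmahpop_params as dpp
from diffmah.diffmahpop_kernels.kdescent_testing import kde1d_wrappers as k1w

ran_key = jran.key(0)
target_key, loss_key = jran.split(ran_key, 2)
lgm_obs, t_obs, lgt0 = 13.0, 10.0, np.log10(13.8)
tarr = np.linspace(0.5, t_obs, k1w.N_T_PER_BIN)

# Self-fit target data generated from the default parameters
u_p = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS
log_mahs_target, weights_target = k1w.get_single_sample_self_fit_target_data(
    u_p, tarr, lgm_obs, t_obs, target_key, lgt0
)

# Loss and gradient evaluated at the same point in parameter space
loss_data = tarr, lgm_obs, t_obs, loss_key, lgt0, log_mahs_target, weights_target
loss, grads = k1w.single_sample_kde_loss_and_grad_self_fit(u_p, *loss_data)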
diffmah/diffmahpop_kernels/kdescent_testing/tests/test_kde1d_wrappers.py (192 additions & 0 deletions)
@@ -0,0 +1,192 @@
""" | ||
""" | ||
|
||
import numpy as np | ||
import pytest | ||
from jax import jit as jjit | ||
from jax import numpy as jnp | ||
from jax import random as jran | ||
|
||
from ... import diffmahpop_params as dpp | ||
from .. import kde1d_wrappers as k1w | ||
|
||
try: | ||
import kdescent # noqa | ||
|
||
HAS_KDESCENT = True | ||
except ImportError: | ||
HAS_KDESCENT = False | ||
|
||
T_MIN_FIT = 0.5 | ||


def test_mc_diffmah_preds():
    ran_key = jran.key(0)
    t_0 = 13.8
    lgt0 = np.log10(t_0)
    n_times = 5

    n_tests = 20
    for __ in range(n_tests):
        ran_key, m_key, t_key = jran.split(ran_key, 3)
        lgm_obs = jran.uniform(m_key, minval=9, maxval=16, shape=())
        t_obs = jran.uniform(t_key, minval=3, maxval=t_0, shape=())
        tarr = np.linspace(T_MIN_FIT, t_obs, n_times)
        pred_data = tarr, lgm_obs, t_obs, ran_key, lgt0
        _preds = k1w.mc_diffmah_preds(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, pred_data)
        for _x in _preds:
            assert np.all(np.isfinite(_x))
        log_mah_tpt0, log_mah_tp, ftpt0 = _preds
        assert np.all(log_mah_tpt0 < 20)
        assert np.all(log_mah_tp < 20)
        assert np.all(ftpt0 <= 1)
        assert np.all(ftpt0 >= 0)


@pytest.mark.skipif("not HAS_KDESCENT")
def test_single_sample_kde_loss_self_fit():
    """Enforce that the single-sample loss has finite grads."""
    ran_key = jran.key(0)

    t_0 = 13.8
    lgt0 = np.log10(t_0)
    DP = 0.5

    n_tests = 100
    for __ in range(n_tests):
        # Use random diffmahpop parameters to generate fiducial data
        u_p_fid_key, ran_key = jran.split(ran_key, 2)
        n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
        uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
        _u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
        u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

        ran_key, lgm_key, t_obs_key = jran.split(ran_key, 3)
        lgm_obs = jran.uniform(lgm_key, minval=11, maxval=15, shape=())
        t_obs = jran.uniform(t_obs_key, minval=4, maxval=t_0, shape=())

        tarr = np.linspace(T_MIN_FIT, t_obs, k1w.N_T_PER_BIN)

        _res = k1w.get_single_sample_self_fit_target_data(
            u_p_fid, tarr, lgm_obs, t_obs, ran_key, lgt0
        )
        log_mahs_target, weights_target = _res

        # Use the default params as the initial guess
        u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(
            dpp.DEFAULT_DIFFMAHPOP_U_PARAMS
        )

        loss_data = tarr, lgm_obs, t_obs, ran_key, lgt0, log_mahs_target, weights_target
        loss = k1w.single_sample_kde_loss_self_fit(u_p_init, *loss_data)
        assert np.all(np.isfinite(loss)), (lgm_obs, t_obs)
        assert loss > 0

        loss, grads = k1w.single_sample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
        assert np.all(np.isfinite(grads)), (lgm_obs, t_obs)


@pytest.mark.skipif("not HAS_KDESCENT")
def test_multisample_kde_loss_self_fit():
    """Enforce that the multi-sample loss has finite grads."""
    ran_key = jran.key(0)

    # Use random diffmahpop parameters to generate fiducial data
    DP = 0.1
    u_p_fid_key, ran_key = jran.split(ran_key, 2)
    n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
    uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
    _u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
    u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

    lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
    num_samples = 5
    t_0 = 13.8
    lgt0 = np.log10(t_0)
    lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
    tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
    num_target_redshifts_per_t_obs = 10

    tarr_matrix = jnp.array(
        [jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
    )
    _keys = jran.split(ran_key, num_samples * 2)
    _res = k1w.get_multisample_self_fit_target_data(
        u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
    )
    log_mahs_targets, weights_targets = _res
    loss_data = (
        tarr_matrix,
        lgmobsarr,
        tobsarr,
        _keys[num_samples:],
        lgt0,
        log_mahs_targets,
        weights_targets,
    )
    # Use the default params as the initial guess
    u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
    loss = k1w.multisample_kde_loss_self_fit(u_p_init, *loss_data)
    assert np.all(np.isfinite(loss))
    assert loss > 0

    loss, grads = k1w.multisample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
    assert np.all(np.isfinite(grads))


@pytest.mark.skipif("not HAS_KDESCENT")
def test_kdescent_adam_self_fit():
    """Enforce that kdescent.adam terminates without NaNs."""
    ran_key = jran.key(0)

    # Use random diffmahpop parameters to generate fiducial data
    DP = 0.5
    u_p_fid_key, ran_key = jran.split(ran_key, 2)
    n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
    uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
    _u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
    u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

    lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
    num_samples = 5
    t_0 = 13.8
    lgt0 = np.log10(t_0)
    lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
    tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
    num_target_redshifts_per_t_obs = 10

    tarr_matrix = jnp.array(
        [jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
    )

    @jjit
    def kde_loss(u_p, randkey):
        # Regenerate the target data with fresh keys, then evaluate the loss
        _keys = jran.split(randkey, num_samples * 2)
        _res = k1w.get_multisample_self_fit_target_data(
            u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
        )
        log_mahs_targets, weights_targets = _res
        loss_data = (
            tarr_matrix,
            lgmobsarr,
            tobsarr,
            _keys[num_samples:],
            lgt0,
            log_mahs_targets,
            weights_targets,
        )
        return k1w.multisample_kde_loss_self_fit(u_p, *loss_data)

    # Use the default params as the initial guess
    u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)

    adam_results = kdescent.adam(
        kde_loss,
        u_p_init,
        nsteps=5,
        learning_rate=0.1,
        randkey=12345,
    )
    u_p_best = adam_results[-1]
    assert np.all(np.isfinite(u_p_best))
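
For an actual fit rather than a smoke test, the same kdescent.adam call can
presumably be run for many more steps; the step count and learning rate below
are illustrative values, not settings prescribed by this PR:

# Hypothetical longer optimization with the kde_loss closure defined above
adam_results = kdescent.adam(
    kde_loss,
    u_p_init,
    nsteps=500,
    learning_rate=0.05,
    randkey=12345,
)
u_p_best = adam_results[-1]  # the last entry holds the final parameters, as above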