-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #129 from ArgonneCPAC/kde_loss_dev_2d
Kde loss dev 2d
- Loading branch information
Showing
2 changed files
with
349 additions
and
0 deletions.
There are no files selected for viewing
155 changes: 155 additions & 0 deletions
155
diffmah/diffmahpop_kernels/kdescent_testing/kde2d_wrappers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
""" | ||
""" | ||
|
||
from jax import random as jran | ||
from jax import value_and_grad, vmap | ||
|
||
try: | ||
import kdescent | ||
except ImportError: | ||
pass | ||
from jax import jit as jjit | ||
from jax import numpy as jnp | ||
|
||
from .. import diffmahpop_params as dpp | ||
from .. import mc_diffmahpop_kernels as mdk | ||
|
||
N_T_PER_BIN = 5 | ||
LGSMAH_MIN = -15 | ||
|
||
|
||
@jjit
def mc_diffmah_preds(diffmahpop_u_params, pred_data):
    """Generate a Monte Carlo sample of halo MAH predictions.

    Parameters
    ----------
    diffmahpop_u_params : namedtuple
        Unbounded diffmahpop parameters, convertible via
        dpp.get_diffmahpop_params_from_u_params
    pred_data : tuple
        (tarr, lgm_obs, t_obs, ran_key, lgt0)

    Returns
    -------
    lgsmah_tpt0 : log10 specific MAH rate of the t_peak == t0 population
    log_mah_tpt0 : log10 MAH of the t_peak == t0 population
    lgsmah_tp : log10 specific MAH rate of the t_peak < t0 population
    log_mah_tp : log10 MAH of the t_peak < t0 population
    ftpt0 : fraction of halos in the t_peak == t0 population

    """
    diffmahpop_params = dpp.get_diffmahpop_params_from_u_params(diffmahpop_u_params)
    tarr, lgm_obs, t_obs, ran_key, lgt0 = pred_data
    _res = mdk._mc_diffmah_halo_sample(
        diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0
    )
    ftpt0 = _res[3]
    dmhdt_tpt0 = _res[5]
    log_mah_tpt0 = _res[6]
    dmhdt_tp = _res[7]
    log_mah_tp = _res[8]
    # Clip accretion rates from below so log10 never returns -inf
    dmhdt_tpt0 = jnp.clip(dmhdt_tpt0, 10**LGSMAH_MIN)
    dmhdt_tp = jnp.clip(dmhdt_tp, 10**LGSMAH_MIN)
    lgsmah_tpt0 = jnp.log10(dmhdt_tpt0) - log_mah_tpt0
    # BUGFIX: normalize the tp population by its own log_mah_tp.
    # The original subtracted log_mah_tpt0 here — an apparent copy-paste
    # error breaking the symmetry with the tpt0 line above.
    lgsmah_tp = jnp.log10(dmhdt_tp) - log_mah_tp
    return lgsmah_tpt0, log_mah_tpt0, lgsmah_tp, log_mah_tp, ftpt0
|
||
|
||
@jjit
def get_single_sample_self_fit_target_data(
    u_params, tarr, lgm_obs, t_obs, ran_key, lgt0
):
    """Build 2D KDE target data (features and weights) from MC predictions.

    Returns
    -------
    X_target : array of shape (n_halos, 2, n_times)
        Axis 1 holds (lgsmah, log_mah) for the stacked tpt0 + tp populations
    weights_target : array of shape (n_halos,)
        ftpt0 weights for the tpt0 half, 1 - ftpt0 for the tp half

    """
    preds = mc_diffmah_preds(u_params, (tarr, lgm_obs, t_obs, ran_key, lgt0))
    lgsmah_tpt0, log_mah_tpt0, lgsmah_tp, log_mah_tp, ftpt0 = preds

    # Stack both populations; each half carries its own population weight
    weights_target = jnp.concatenate((ftpt0, 1 - ftpt0))
    lgsmah_target = jnp.concatenate((lgsmah_tpt0, lgsmah_tp))
    log_mahs_target = jnp.concatenate((log_mah_tpt0, log_mah_tp))

    # (2, n_halos, n_times) -> (n_halos, 2, n_times)
    X_target = jnp.stack((lgsmah_target, log_mahs_target), axis=1)
    return X_target, weights_target
|
||
|
||
@jjit
def single_sample_kde_loss_self_fit(
    diffmahpop_u_params,
    tarr,
    lgm_obs,
    t_obs,
    ran_key,
    lgt0,
    X_target,
    weights_target,
):
    """Compute the 2D-KDE loss of the model against self-fit target data.

    Parameters
    ----------
    diffmahpop_u_params : namedtuple of unbounded diffmahpop parameters
    tarr : array of shape (N_T_PER_BIN,)
    lgm_obs, t_obs : floats specifying the observed mass and time
    ran_key : jax PRNG key
    lgt0 : float, base-10 log of the age of the universe
    X_target : array of shape (n_halos, 2, N_T_PER_BIN)
    weights_target : array of shape (n_halos,)

    Returns
    -------
    loss : float, sum over time slices of the mean squared KDE-count residual

    """
    # One KCalc per time slice of the target data.
    # The Python loops below unroll at trace time under jit, so behavior
    # is identical to writing out the N_T_PER_BIN copies by hand.
    kcalcs = [
        kdescent.KCalc(X_target[:, :, i], weights_target)
        for i in range(N_T_PER_BIN)
    ]

    ran_key, pred_key = jran.split(ran_key, 2)
    pred_data = tarr, lgm_obs, t_obs, pred_key, lgt0
    _res = mc_diffmah_preds(diffmahpop_u_params, pred_data)
    # Renamed from dmhdt_* in the original: mc_diffmah_preds returns the
    # log10 specific MAH rates, not raw accretion rates
    lgsmah_tpt0, log_mah_tpt0, lgsmah_tp, log_mah_tp, ftpt0 = _res

    weights_pred = jnp.concatenate((ftpt0, 1 - ftpt0))
    lgsmahs_pred = jnp.concatenate((lgsmah_tpt0, lgsmah_tp))
    log_mahs_pred = jnp.concatenate((log_mah_tpt0, log_mah_tp))
    X_preds = jnp.array((lgsmahs_pred, log_mahs_pred)).swapaxes(0, 1)

    kcalc_keys = jran.split(ran_key, N_T_PER_BIN)

    # Sum the mean squared KDE-count residual over every time slice
    loss = 0.0
    for i, kcalc in enumerate(kcalcs):
        model_counts, truth_counts = kcalc.compare_kde_counts(
            kcalc_keys[i], X_preds[:, :, i], weights_pred
        )
        diff = model_counts - truth_counts
        loss = loss + jnp.mean(diff**2)
    return loss
|
||
|
||
# Jitted value-and-gradient of the single-sample loss w.r.t. diffmahpop_u_params
single_sample_kde_loss_and_grad_self_fit = jjit(
    value_and_grad(single_sample_kde_loss_self_fit)
)

# vmap over (tarr, lgm_obs, t_obs, ran_key); u_params and lgt0 are shared
_A = (None, 0, 0, 0, 0, None)
get_multisample_self_fit_target_data = jjit(
    vmap(get_single_sample_self_fit_target_data, in_axes=_A)
)

# Same mapping as _A plus per-sample X_target and weights_target axes
_L = (None, 0, 0, 0, 0, None, 0, 0)
_multisample_kde_loss_self_fit = jjit(vmap(single_sample_kde_loss_self_fit, in_axes=_L))
|
||
|
||
@jjit
def multisample_kde_loss_self_fit(
    diffmahpop_u_params,
    tarr_matrix,
    lgmobsarr,
    tobsarr,
    ran_keys,
    lgt0,
    X_targets,
    weights_targets,
):
    """Average the single-sample KDE loss over all (lgm_obs, t_obs) samples.

    Thin wrapper around the vmapped single-sample loss: each row of
    tarr_matrix / lgmobsarr / tobsarr / ran_keys / X_targets /
    weights_targets defines one sample; diffmahpop_u_params and lgt0
    are shared across samples.

    """
    per_sample_losses = _multisample_kde_loss_self_fit(
        diffmahpop_u_params,
        tarr_matrix,
        lgmobsarr,
        tobsarr,
        ran_keys,
        lgt0,
        X_targets,
        weights_targets,
    )
    return jnp.mean(per_sample_losses)
|
||
|
||
# Jitted value-and-gradient of the multi-sample loss w.r.t. diffmahpop_u_params
multisample_kde_loss_and_grad_self_fit = jjit(
    value_and_grad(multisample_kde_loss_self_fit)
)
194 changes: 194 additions & 0 deletions
194
diffmah/diffmahpop_kernels/kdescent_testing/tests/test_kde2d_wrappers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
""" | ||
""" | ||
|
||
import numpy as np | ||
import pytest | ||
from jax import jit as jjit | ||
from jax import numpy as jnp | ||
from jax import random as jran | ||
|
||
from ... import diffmahpop_params as dpp | ||
from .. import kde2d_wrappers as k2w | ||
|
||
try: | ||
import kdescent # noqa | ||
|
||
HAS_KDESCENT = True | ||
except ImportError: | ||
HAS_KDESCENT = False | ||
|
||
T_MIN_FIT = 0.5 | ||
|
||
|
||
def test_mc_diffmah_preds():
    """Smoke test: MC predictions are finite and within loose physical bounds."""
    ran_key = jran.key(0)
    t_0 = 13.8
    lgt0 = np.log10(t_0)
    n_times = 5

    n_tests = 20
    for __ in range(n_tests):
        # Draw a random (lgm_obs, t_obs) pair for this realization
        ran_key, m_key, t_key = jran.split(ran_key, 3)
        lgm_obs = jran.uniform(m_key, minval=9, maxval=16, shape=())
        t_obs = jran.uniform(t_key, minval=3, maxval=t_0, shape=())
        tarr = np.linspace(T_MIN_FIT, t_obs, n_times)

        preds = k2w.mc_diffmah_preds(
            dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, (tarr, lgm_obs, t_obs, ran_key, lgt0)
        )
        for arr in preds:
            assert np.all(np.isfinite(arr))

        lgsmar_tpt0, log_mah_tpt0, lgsmar_tp, log_mah_tp, ftpt0 = preds
        # Loose upper bound on every log-space quantity
        for arr in (lgsmar_tpt0, lgsmar_tp, log_mah_tpt0, log_mah_tp):
            assert np.all(arr < 20)
        # ftpt0 is a population fraction
        assert np.all(ftpt0 >= 0)
        assert np.all(ftpt0 <= 1)
|
||
|
||
@pytest.mark.skipif("not HAS_KDESCENT")
def test_single_sample_kde_loss_self_fit():
    """Enforce that single-sample loss has finite grads"""
    ran_key = jran.key(0)

    t_0 = 13.8
    lgt0 = np.log10(t_0)
    # Half-width of the uniform perturbation applied to each u_param
    DP = 0.5

    n_tests = 100
    for __ in range(n_tests):

        # Use random diffmahpop parameter to generate fiducial data
        u_p_fid_key, ran_key = jran.split(ran_key, 2)
        n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
        uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
        _u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
        u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

        ran_key, lgm_key, t_obs_key = jran.split(ran_key, 3)
        lgm_obs = jran.uniform(lgm_key, minval=11, maxval=15, shape=())
        t_obs = jran.uniform(t_obs_key, minval=4, maxval=t_0, shape=())

        tarr = np.linspace(T_MIN_FIT, t_obs, k2w.N_T_PER_BIN)

        # NOTE(review): ran_key is reused here and again in loss_data below
        # without a fresh split — presumably intentional for a self-fit
        # (same MC realization on both sides), but worth confirming
        _res = k2w.get_single_sample_self_fit_target_data(
            u_p_fid, tarr, lgm_obs, t_obs, ran_key, lgt0
        )
        X_target, weights_target = _res

        # use default params as the initial guess
        u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(
            dpp.DEFAULT_DIFFMAHPOP_U_PARAMS
        )

        loss_data = tarr, lgm_obs, t_obs, ran_key, lgt0, X_target, weights_target
        loss = k2w.single_sample_kde_loss_self_fit(u_p_init, *loss_data)
        assert np.all(np.isfinite(loss)), (lgm_obs, t_obs)
        assert loss > 0

        # Gradient w.r.t. every u_param must also be finite
        loss, grads = k2w.single_sample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
        assert np.all(np.isfinite(grads)), (lgm_obs, t_obs)
|
||
|
||
@pytest.mark.skipif("not HAS_KDESCENT")
def test_multisample_kde_loss_self_fit():
    """Enforce that multi-sample loss has finite grads"""
    ran_key = jran.key(0)

    # Use random diffmahpop parameter to generate fiducial data
    DP = 0.1
    u_p_fid_key, ran_key = jran.split(ran_key, 2)
    n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
    uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
    _u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
    u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

    lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
    num_samples = 5
    t_0 = 13.8
    lgt0 = np.log10(t_0)
    lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
    tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
    num_target_redshifts_per_t_obs = 10

    # One time array per t_obs, all starting at T_MIN_FIT
    tarr_matrix = jnp.array(
        [jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
    )
    # First num_samples keys generate targets; the rest drive the loss
    _keys = jran.split(ran_key, num_samples * 2)
    _res = k2w.get_multisample_self_fit_target_data(
        u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
    )
    X_targets, weights_targets = _res
    loss_data = (
        tarr_matrix,
        lgmobsarr,
        tobsarr,
        _keys[num_samples:],
        lgt0,
        X_targets,
        weights_targets,
    )
    # use default params as the initial guess
    u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
    loss = k2w.multisample_kde_loss_self_fit(u_p_init, *loss_data)
    assert np.all(np.isfinite(loss))
    assert loss > 0

    # Gradient w.r.t. every u_param must also be finite
    loss, grads = k2w.multisample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
    assert np.all(np.isfinite(grads))
|
||
|
||
@pytest.mark.skipif("not HAS_KDESCENT")
def test_kdescent_adam_self_fit():
    """Enforce that kdescent.adam terminates without NaNs"""
    ran_key = jran.key(0)

    # Use random diffmahpop parameter to generate fiducial data
    DP = 0.5
    u_p_fid_key, ran_key = jran.split(ran_key, 2)
    n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
    uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
    _u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
    u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

    lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
    num_samples = 5
    t_0 = 13.8
    lgt0 = np.log10(t_0)
    lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
    tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
    num_target_redshifts_per_t_obs = 10

    # One time array per t_obs, all starting at T_MIN_FIT
    tarr_matrix = jnp.array(
        [jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
    )

    @jjit
    def kde_loss(u_p, randkey):
        # Regenerates target data from the fiducial params u_p_fid (closed
        # over) on every call, keyed by the optimizer-supplied randkey;
        # only u_p is optimized
        _keys = jran.split(randkey, num_samples * 2)
        _res = k2w.get_multisample_self_fit_target_data(
            u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
        )
        X_targets, weights_targets = _res
        loss_data = (
            tarr_matrix,
            lgmobsarr,
            tobsarr,
            _keys[num_samples:],
            lgt0,
            X_targets,
            weights_targets,
        )
        return k2w.multisample_kde_loss_self_fit(u_p, *loss_data)

    # use default params as the initial guess
    u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)

    # A handful of adam steps suffices to check numerical stability
    adam_results = kdescent.adam(
        kde_loss,
        u_p_init,
        nsteps=5,
        learning_rate=0.1,
        randkey=12345,
    )
    # Final entry holds the best-fit parameters; must be NaN-free
    u_p_best = adam_results[-1]
    assert np.all(np.isfinite(u_p_best))