Commit
Merge pull request #129 from ArgonneCPAC/kde_loss_dev_2d
Kde loss dev 2d
aphearin committed Jul 22, 2024
2 parents 2efee6a + 0ade3b8 commit ba32dde
Showing 2 changed files with 349 additions and 0 deletions.
155 changes: 155 additions & 0 deletions diffmah/diffmahpop_kernels/kdescent_testing/kde2d_wrappers.py
@@ -0,0 +1,155 @@
"""
"""

from jax import random as jran
from jax import value_and_grad, vmap

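# kdescent is an optional dependency, so tolerate a missing install at import time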
try:
    import kdescent
except ImportError:
    pass
from jax import jit as jjit
from jax import numpy as jnp

from .. import diffmahpop_params as dpp
from .. import mc_diffmahpop_kernels as mdk

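# N_T_PER_BIN: time samples per (lgm_obs, t_obs) target; LGSMAH_MIN: log10 floor on dMh/dt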
N_T_PER_BIN = 5
LGSMAH_MIN = -15


@jjit
def mc_diffmah_preds(diffmahpop_u_params, pred_data):
    diffmahpop_params = dpp.get_diffmahpop_params_from_u_params(diffmahpop_u_params)
    tarr, lgm_obs, t_obs, ran_key, lgt0 = pred_data
    _res = mdk._mc_diffmah_halo_sample(
        diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0
    )
    # Unpack ftpt0 and the (dmhdt, log_mah) histories of the tpt0 and tp populations
    ftpt0 = _res[3]
    dmhdt_tpt0 = _res[5]
    log_mah_tpt0 = _res[6]
    dmhdt_tp = _res[7]
    log_mah_tp = _res[8]
    # Clip dMh/dt from below so that log10 is finite
    dmhdt_tpt0 = jnp.clip(dmhdt_tpt0, 10**LGSMAH_MIN)
    dmhdt_tp = jnp.clip(dmhdt_tp, 10**LGSMAH_MIN)
    # Log specific accretion rate: normalize each dMh/dt by the same population's mass
    lgsmah_tpt0 = jnp.log10(dmhdt_tpt0) - log_mah_tpt0
    lgsmah_tp = jnp.log10(dmhdt_tp) - log_mah_tp
    return lgsmah_tpt0, log_mah_tpt0, lgsmah_tp, log_mah_tp, ftpt0


@jjit
def get_single_sample_self_fit_target_data(
    u_params, tarr, lgm_obs, t_obs, ran_key, lgt0
):
    pred_data = tarr, lgm_obs, t_obs, ran_key, lgt0
    _res = mc_diffmah_preds(u_params, pred_data)
    lgsmah_tpt0, log_mah_tpt0, lgsmah_tp, log_mah_tp, ftpt0 = _res
    weights_target = jnp.concatenate((ftpt0, 1 - ftpt0))
    lgsmah_target = jnp.concatenate((lgsmah_tpt0, lgsmah_tp))
    log_mahs_target = jnp.concatenate((log_mah_tpt0, log_mah_tp))
    X_target = jnp.array((lgsmah_target, log_mahs_target)).swapaxes(0, 1)
    return X_target, weights_target


@jjit
def single_sample_kde_loss_self_fit(
    diffmahpop_u_params,
    tarr,
    lgm_obs,
    t_obs,
    ran_key,
    lgt0,
    X_target,
    weights_target,
):
    # One KCalc per time sample comparing the 2-d (lgsmah, log_mah) distribution
    kcalc0 = kdescent.KCalc(X_target[:, :, 0], weights_target)
    kcalc1 = kdescent.KCalc(X_target[:, :, 1], weights_target)
    kcalc2 = kdescent.KCalc(X_target[:, :, 2], weights_target)
    kcalc3 = kdescent.KCalc(X_target[:, :, 3], weights_target)
    kcalc4 = kdescent.KCalc(X_target[:, :, 4], weights_target)

    ran_key, pred_key = jran.split(ran_key, 2)
    pred_data = tarr, lgm_obs, t_obs, pred_key, lgt0
    _res = mc_diffmah_preds(diffmahpop_u_params, pred_data)
    lgsmah_tpt0, log_mah_tpt0, lgsmah_tp, log_mah_tp, ftpt0 = _res

    # Pack predictions to match the layout of X_target
    weights_pred = jnp.concatenate((ftpt0, 1 - ftpt0))
    lgsmahs_pred = jnp.concatenate((lgsmah_tpt0, lgsmah_tp))
    log_mahs_pred = jnp.concatenate((log_mah_tpt0, log_mah_tp))
    X_preds = jnp.array((lgsmahs_pred, log_mahs_pred)).swapaxes(0, 1)

    kcalc_keys = jran.split(ran_key, N_T_PER_BIN)

    model_counts0, truth_counts0 = kcalc0.compare_kde_counts(
        kcalc_keys[0], X_preds[:, :, 0], weights_pred
    )
    model_counts1, truth_counts1 = kcalc1.compare_kde_counts(
        kcalc_keys[1], X_preds[:, :, 1], weights_pred
    )
    model_counts2, truth_counts2 = kcalc2.compare_kde_counts(
        kcalc_keys[2], X_preds[:, :, 2], weights_pred
    )
    model_counts3, truth_counts3 = kcalc3.compare_kde_counts(
        kcalc_keys[3], X_preds[:, :, 3], weights_pred
    )
    model_counts4, truth_counts4 = kcalc4.compare_kde_counts(
        kcalc_keys[4], X_preds[:, :, 4], weights_pred
    )

    diff0 = model_counts0 - truth_counts0
    diff1 = model_counts1 - truth_counts1
    diff2 = model_counts2 - truth_counts2
    diff3 = model_counts3 - truth_counts3
    diff4 = model_counts4 - truth_counts4

    loss0 = jnp.mean(diff0**2)
    loss1 = jnp.mean(diff1**2)
    loss2 = jnp.mean(diff2**2)
    loss3 = jnp.mean(diff3**2)
    loss4 = jnp.mean(diff4**2)

    loss = loss0 + loss1 + loss2 + loss3 + loss4
    return loss


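# Loss and gradient with respect to diffmahpop_u_params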
single_sample_kde_loss_and_grad_self_fit = jjit(
    value_and_grad(single_sample_kde_loss_self_fit)
)

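# vmap over per-sample (tarr, lgm_obs, t_obs, ran_key); u_params and lgt0 are shared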
_A = (None, 0, 0, 0, 0, None)
get_multisample_self_fit_target_data = jjit(
    vmap(get_single_sample_self_fit_target_data, in_axes=_A)
)

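# Also vmap over the per-sample X_target and weights_target when evaluating the loss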
_L = (None, 0, 0, 0, 0, None, 0, 0)
_multisample_kde_loss_self_fit = jjit(vmap(single_sample_kde_loss_self_fit, in_axes=_L))


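# Multi-sample loss: average the single-sample KDE losses over all (lgm_obs, t_obs) samples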
@jjit
def multisample_kde_loss_self_fit(
    diffmahpop_u_params,
    tarr_matrix,
    lgmobsarr,
    tobsarr,
    ran_keys,
    lgt0,
    X_targets,
    weights_targets,
):
    losses = _multisample_kde_loss_self_fit(
        diffmahpop_u_params,
        tarr_matrix,
        lgmobsarr,
        tobsarr,
        ran_keys,
        lgt0,
        X_targets,
        weights_targets,
    )
    loss = jnp.mean(losses)
    return loss


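# Loss and gradient of the sample-averaged loss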
multisample_kde_loss_and_grad_self_fit = jjit(
    value_and_grad(multisample_kde_loss_self_fit)
)
194 changes: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
"""
"""

import numpy as np
import pytest
from jax import jit as jjit
from jax import numpy as jnp
from jax import random as jran

from ... import diffmahpop_params as dpp
from .. import kde2d_wrappers as k2w

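# kdescent is optional; the tests below are skipped when it is not installed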
try:
    import kdescent  # noqa

    HAS_KDESCENT = True
except ImportError:
    HAS_KDESCENT = False

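# Earliest time used in the target time arrays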
T_MIN_FIT = 0.5


def test_mc_diffmah_preds():
    """Enforce that mc_diffmah_preds returns finite predictions"""
    ran_key = jran.key(0)
    t_0 = 13.8
    lgt0 = np.log10(t_0)
    n_times = 5

    n_tests = 20
    for __ in range(n_tests):
        ran_key, m_key, t_key = jran.split(ran_key, 3)
        lgm_obs = jran.uniform(m_key, minval=9, maxval=16, shape=())
        t_obs = jran.uniform(t_key, minval=3, maxval=t_0, shape=())
        tarr = np.linspace(T_MIN_FIT, t_obs, n_times)
        pred_data = tarr, lgm_obs, t_obs, ran_key, lgt0
        _preds = k2w.mc_diffmah_preds(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, pred_data)
        for _x in _preds:
            assert np.all(np.isfinite(_x))
        lgsmar_tpt0, log_mah_tpt0, lgsmar_tp, log_mah_tp, ftpt0 = _preds
        assert np.all(lgsmar_tpt0 < 20)
        assert np.all(lgsmar_tp < 20)
        assert np.all(log_mah_tpt0 < 20)
        assert np.all(log_mah_tp < 20)
        assert np.all(ftpt0 <= 1)
        assert np.all(ftpt0 >= 0)


@pytest.mark.skipif("not HAS_KDESCENT")
def test_single_sample_kde_loss_self_fit():
"""Enforce that single-sample loss has finite grads"""
ran_key = jran.key(0)

t_0 = 13.8
lgt0 = np.log10(t_0)
DP = 0.5

n_tests = 100
for __ in range(n_tests):

# Use random diffmahpop parameter to generate fiducial data
u_p_fid_key, ran_key = jran.split(ran_key, 2)
n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
_u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

ran_key, lgm_key, t_obs_key = jran.split(ran_key, 3)
lgm_obs = jran.uniform(lgm_key, minval=11, maxval=15, shape=())
t_obs = jran.uniform(t_obs_key, minval=4, maxval=t_0, shape=())

tarr = np.linspace(T_MIN_FIT, t_obs, k2w.N_T_PER_BIN)

_res = k2w.get_single_sample_self_fit_target_data(
u_p_fid, tarr, lgm_obs, t_obs, ran_key, lgt0
)
X_target, weights_target = _res

# use default params as the initial guess
u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(
dpp.DEFAULT_DIFFMAHPOP_U_PARAMS
)

loss_data = tarr, lgm_obs, t_obs, ran_key, lgt0, X_target, weights_target
loss = k2w.single_sample_kde_loss_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(loss)), (lgm_obs, t_obs)
assert loss > 0

loss, grads = k2w.single_sample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(grads)), (lgm_obs, t_obs)


@pytest.mark.skipif("not HAS_KDESCENT")
def test_multisample_kde_loss_self_fit():
"""Enforce that multi-sample loss has finite grads"""
ran_key = jran.key(0)

# Use random diffmahpop parameter to generate fiducial data
DP = 0.1
u_p_fid_key, ran_key = jran.split(ran_key, 2)
n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
_u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
num_samples = 5
t_0 = 13.8
lgt0 = np.log10(t_0)
lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
num_target_redshifts_per_t_obs = 10

tarr_matrix = jnp.array(
[jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
)
_keys = jran.split(ran_key, num_samples * 2)
_res = k2w.get_multisample_self_fit_target_data(
u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
)
X_targets, weights_targets = _res
loss_data = (
tarr_matrix,
lgmobsarr,
tobsarr,
_keys[num_samples:],
lgt0,
X_targets,
weights_targets,
)
# use default params as the initial guess
u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
loss = k2w.multisample_kde_loss_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(loss))
assert loss > 0

loss, grads = k2w.multisample_kde_loss_and_grad_self_fit(u_p_init, *loss_data)
assert np.all(np.isfinite(grads))


@pytest.mark.skipif("not HAS_KDESCENT")
def test_kdescent_adam_self_fit():
"""Enforce that kdescent.adam terminates without NaNs"""
ran_key = jran.key(0)

# Use random diffmahpop parameter to generate fiducial data
DP = 0.5
u_p_fid_key, ran_key = jran.split(ran_key, 2)
n_params = len(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)
uran = jran.uniform(u_p_fid_key, minval=-DP, maxval=DP, shape=(n_params,))
_u_p_list = [x + u for x, u in zip(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, uran)]
u_p_fid = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(jnp.array(_u_p_list))

lgmobs_key, tobs_key, ran_key = jran.split(ran_key, 3)
num_samples = 5
t_0 = 13.8
lgt0 = np.log10(t_0)
lgmobsarr = jran.uniform(lgmobs_key, minval=11, maxval=15, shape=(num_samples,))
tobsarr = jran.uniform(tobs_key, minval=4, maxval=13, shape=(num_samples,))
num_target_redshifts_per_t_obs = 10

tarr_matrix = jnp.array(
[jnp.linspace(T_MIN_FIT, t, num_target_redshifts_per_t_obs) for t in tobsarr]
)

@jjit
def kde_loss(u_p, randkey):
_keys = jran.split(randkey, num_samples * 2)
_res = k2w.get_multisample_self_fit_target_data(
u_p_fid, tarr_matrix, lgmobsarr, tobsarr, _keys[:num_samples], lgt0
)
X_targets, weights_targets = _res
loss_data = (
tarr_matrix,
lgmobsarr,
tobsarr,
_keys[num_samples:],
lgt0,
X_targets,
weights_targets,
)
return k2w.multisample_kde_loss_self_fit(u_p, *loss_data)

# use default params as the initial guess
u_p_init = dpp.DEFAULT_DIFFMAHPOP_U_PARAMS._make(dpp.DEFAULT_DIFFMAHPOP_U_PARAMS)

adam_results = kdescent.adam(
kde_loss,
u_p_init,
nsteps=5,
learning_rate=0.1,
randkey=12345,
)
u_p_best = adam_results[-1]
assert np.all(np.isfinite(u_p_best))
