Merge pull request #105 from ArgonneCPAC/jax_adam_wrapper_tol
Add optional kwarg tol to jax_adam_wrapper
aphearin committed Jul 11, 2023
2 parents 61ea474 + a915618 commit 9a8d46f
Showing 2 changed files with 63 additions and 14 deletions.
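The headline change: jax_adam_wrapper gains an optional kwarg tol (default 0.0), and gradient descent now stops as soon as the loss drops below tol. Below is a minimal sketch of the intended call pattern; it is not part of the commit, and the quadratic loss, toy data, and variable names are illustrative, mirroring the new test.

    import numpy as np
    from jax import jit as jax_jit
    from jax import numpy as jax_np
    from jax import value_and_grad

    from diffmah.utils import jax_adam_wrapper


    @jax_jit
    def mse_loss(params, data):
        # Quadratic loss with a known minimum at params = (3, 1)
        x = data[0]
        target = 3 * (x - 1)
        a, b = params
        pred = a * (x - b)
        diff = pred - target
        return jax_np.sum(diff * diff) / diff.size


    @jax_jit
    def mse_loss_and_grad(params, data):
        # jax_adam_wrapper expects a callable returning (loss, grads)
        return value_and_grad(mse_loss, argnums=0)(params, data)


    params_init = np.array((2.75, 0.75))
    data = (np.linspace(-1, 1, 50),)

    # With tol > 0, Adam stops early once the loss falls below tol, so the fit
    # may use fewer than n_step gradient steps.
    p_best, loss_best, loss_arr, params_arr, flag = jax_adam_wrapper(
        mse_loss_and_grad, params_init, data, 100, step_size=0.01, tol=1e-3
    )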
36 changes: 36 additions & 0 deletions diffmah/tests/test_utils.py
@@ -44,3 +44,39 @@ def mse_loss_and_grad(params, data):

     params_correct = [3, 1]
     assert np.allclose(params_bestfit, params_correct, atol=0.01)
+
+
+def test_jax_adam_wrapper_loss_tol_feature_works():
+    @jax_jit
+    def mse_loss(params, data):
+        x = data[0]
+        target = 3 * (x - 1)
+        a, b = params
+        pred = a * (x - b)
+        diff = pred - target
+        loss = jax_np.sum(diff * diff) / diff.size
+        return loss
+
+    @jax_jit
+    def mse_loss_and_grad(params, data):
+        return value_and_grad(mse_loss, argnums=0)(params, data)
+
+    params_init = np.array((2.75, 0.75))
+    x = np.linspace(-1, 1, 50)
+
+    data = (x,)
+    loss_init = mse_loss(params_init, data)
+    n_step = 100
+    params_bestfit, loss_bestfit0, loss_arr, params_arr, flag = jax_adam_wrapper(
+        mse_loss_and_grad, params_init, data, n_step, step_size=0.01, tol=1e-2
+    )
+    params_bestfit, loss_bestfit1, loss_arr, params_arr, flag = jax_adam_wrapper(
+        mse_loss_and_grad, params_init, data, n_step, step_size=0.01, tol=1e-3
+    )
+    params_bestfit, loss_bestfit2, loss_arr, params_arr, flag = jax_adam_wrapper(
+        mse_loss_and_grad, params_init, data, n_step, step_size=0.01, tol=1e-4
+    )
+    assert loss_bestfit0 <= 1e-2
+    assert loss_bestfit1 <= 1e-3
+    assert loss_bestfit2 <= 1e-4
+    assert loss_bestfit2 < loss_bestfit1 < loss_bestfit0 < loss_init
41 changes: 27 additions & 14 deletions diffmah/utils.py
@@ -79,6 +79,7 @@ def jax_adam_wrapper(
     step_size=0.01,
     warmup_n_step=50,
     warmup_step_size=None,
+    tol=0.0,
 ):
     """Convenience function wrapping JAX's Adam optimizer used to
     minimize the loss function loss_func.
@@ -139,32 +140,38 @@
         warmup_step_size = 5 * step_size

     p_init = np.copy(params_init)
+    loss_init = float("inf")
     for i in range(n_warmup):
-        p_init = _jax_adam_wrapper(
+        fit_results = _jax_adam_wrapper(
             loss_and_grad_func,
             p_init,
             loss_data,
             warmup_n_step,
             step_size=warmup_step_size,
-        )[0]
+            tol=tol,
+        )
+        p_init = fit_results[0]
+        loss_init = fit_results[1]

     if np.all(np.isfinite(p_init)):
         p0 = p_init
     else:
         p0 = params_init

-    _res = _jax_adam_wrapper(
-        loss_and_grad_func, p0, loss_data, n_step, step_size=step_size
-    )
-    if len(_res[2]) < n_step:
+    if loss_init > tol:
+        fit_results = _jax_adam_wrapper(
+            loss_and_grad_func, p0, loss_data, n_step, step_size=step_size, tol=tol
+        )
+
+    if len(fit_results[2]) < n_step:
         fit_terminates = 0
     else:
         fit_terminates = 1
-    return (*_res, fit_terminates)
+    return (*fit_results, fit_terminates)


 def _jax_adam_wrapper(
-    loss_and_grad_func, params_init, loss_data, n_step, step_size=0.01
+    loss_and_grad_func, params_init, loss_data, n_step, step_size=0.01, tol=0.0
 ):
     """Convenience function wrapping JAX's Adam optimizer used to
     minimize the loss function loss_func.
@@ -207,7 +214,7 @@ def _jax_adam_wrapper(
         Stores the value of the model params at each step
     """
-    loss_arr = np.zeros(n_step).astype("f4") - 1.0
+    loss_collector = []
     opt_init, opt_update, get_params = jax_opt.adam(step_size)
     opt_state = opt_init(params_init)
     n_params = len(params_init)
@@ -217,29 +224,35 @@
         p = np.array(get_params(opt_state))

         loss, grads = loss_and_grad_func(p, loss_data)
+        loss_collector.append(loss)

         no_nan_params = np.all(np.isfinite(p))
         no_nan_loss = np.isfinite(loss)
         no_nan_grads = np.all(np.isfinite(grads))
-        if ~no_nan_params | ~no_nan_loss | ~no_nan_grads:
+        has_nans = ~no_nan_params | ~no_nan_loss | ~no_nan_grads
+        if has_nans:
             if istep > 0:
-                indx_best = np.nanargmin(loss_arr[:istep])
+                indx_best = np.nanargmin(loss_collector[:istep])
                 best_fit_params = params_arr[indx_best]
-                best_fit_loss = loss_arr[indx_best]
+                best_fit_loss = loss_collector[indx_best]
             else:
                 best_fit_params = np.copy(p)
                 best_fit_loss = 999.99
             return (
                 best_fit_params,
                 best_fit_loss,
-                loss_arr[:istep],
+                np.array(loss_collector[:istep]),
                 params_arr[:istep, :],
             )
         else:
             params_arr[istep, :] = p
-            loss_arr[istep] = loss
             opt_state = opt_update(istep, grads, opt_state)
+            if loss < tol:
+                best_fit_params = p
+                loss_arr = np.array(loss_collector)
+                return best_fit_params, loss, loss_arr, params_arr

+    loss_arr = np.array(loss_collector)
     indx_best = np.nanargmin(loss_arr)
     best_fit_params = params_arr[indx_best]
     loss = loss_arr[indx_best]
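As the diff above shows, the early-stopping behavior is visible in the return values: when the loss falls below tol before n_step iterations (or non-finite values appear), _jax_adam_wrapper returns a loss history shorter than n_step, and the public wrapper reports this through the final element of its return tuple (1 if all n_step iterations ran, 0 otherwise). A small sketch of how a caller might check for early termination follows; it is not part of the commit, and the toy quadratic loss and names (quad_loss, quad_loss_and_grad) are illustrative.

    import numpy as np
    from jax import jit as jax_jit
    from jax import numpy as jax_np
    from jax import value_and_grad

    from diffmah.utils import jax_adam_wrapper


    @jax_jit
    def quad_loss(params, data):
        # Simple quadratic bowl with its minimum at params = target
        target = data[0]
        return jax_np.sum((params - target) ** 2)


    @jax_jit
    def quad_loss_and_grad(params, data):
        return value_and_grad(quad_loss, argnums=0)(params, data)


    params_init = np.array((1.0, -1.0))
    data = (np.zeros(2),)
    n_step = 500

    p_best, loss_best, loss_arr, params_arr, flag = jax_adam_wrapper(
        quad_loss_and_grad, params_init, data, n_step, step_size=0.05, tol=1e-6
    )

    # flag == 0 means the optimizer stopped before completing n_step iterations,
    # either because the loss dropped below tol or because non-finite values
    # appeared; the returned loss history only covers the steps actually taken.
    if flag == 0:
        print(f"stopped early after {loss_arr.size} of {n_step} steps")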
