Feature/5/add mulit fidelity optimization #13

Merged · 9 commits · Apr 21, 2024
2 changes: 1 addition & 1 deletion boax/core/distributions/gamma.py
@@ -37,7 +37,7 @@ class Gamma(NamedTuple):

 def gamma(a: Array, b: Array = jnp.ones(())) -> Gamma:
   """
-  Smart constructor for the beta distribution.
+  Smart constructor for the gamma distribution.

   Args:
     a: The shape parameter.
2 changes: 1 addition & 1 deletion boax/core/distributions/poisson.py
@@ -34,7 +34,7 @@ class Poisson(NamedTuple):

 def poisson(mu: Array) -> Poisson:
   """
-  Smart constructor for the beta distribution.
+  Smart constructor for the poisson distribution.

   Args:
     mu: The rate parameter.
2 changes: 2 additions & 0 deletions boax/core/samplers/__init__.py
@@ -16,4 +16,6 @@

 from .alias import halton_normal as halton_normal
 from .alias import halton_uniform as halton_uniform
+from .alias import normal as normal
+from .alias import uniform as uniform
 from .base import Sampler as Sampler
56 changes: 53 additions & 3 deletions boax/core/samplers/alias.py
@@ -27,15 +27,65 @@
 from boax.utils.functools import compose


+def uniform(
+  uniform: Uniform = Uniform(jnp.zeros((1,)), jnp.ones((1,))),
+) -> Sampler:
+  """
+  The i.i.d. uniform sampler.
+
+  Example:
+    >>> sampler = uniform()
+    >>> base_samples = sampler(key, (128,))
+
+  Args:
+    uniform: The base uniform distribution.
+
+  Returns:
+    The corresponding `Sampler`.
+  """
+
+  out_shape = lax.broadcast_shapes(uniform.a.shape, uniform.b.shape)
+
+  return compose(
+    partial(partial, distributions.uniform.sample)(uniform),
+    partial(functions.iid.uniform, ndims=out_shape[0]),
+  )
+
+
+def normal(
+  normal: Normal = Normal(jnp.zeros((1,)), jnp.ones((1,))),
+) -> Sampler:
+  """
+  The i.i.d. normal sampler.
+
+  Example:
+    >>> sampler = normal()
+    >>> base_samples = sampler(key, (128,))
+
+  Args:
+    normal: The base normal distribution.
+
+  Returns:
+    The corresponding `Sampler`.
+  """
+
+  out_shape = lax.broadcast_shapes(normal.loc.shape, normal.scale.shape)
+
+  return compose(
+    partial(partial, distributions.normal.sample)(normal),
+    partial(functions.iid.normal, ndims=out_shape[0]),
+  )
+
+
 def halton_uniform(
   uniform: Uniform = Uniform(jnp.zeros((1,)), jnp.ones((1,))),
 ) -> Sampler:
   """
   The quasi-MC uniform sampler based on halton sequences.

   Example:
-    >>> sampler = halton_uniform(uniform)
-    >>> base_samples = sampler(key, 128)
+    >>> sampler = halton_uniform()
+    >>> base_samples = sampler(key, (128,))

   Args:
     uniform: The base uniform distribution.

@@ -66,7 +116,7 @@ def halton_normal(
   The quasi-MC normal sampler based on halton sequences.

   Example:
-    >>> sampler = halton_normal(normal)
+    >>> sampler = halton_normal()
     >>> base_samples = sampler(key, 128)

   Args:
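For orientation (not part of the diff): a minimal usage sketch of the new sampler API. The sampler imports follow the updated __init__.py above; the assumption that Uniform lives in boax.core.distributions.uniform mirrors the Normal import seen later in this PR, and the 2-dimensional unit cube is made up for illustration.

from jax import numpy as jnp
from jax import random

from boax.core.distributions.uniform import Uniform
from boax.core.samplers import halton_uniform, normal

key = random.PRNGKey(0)

# New i.i.d. standard-normal sampler; the second argument is now a
# sample shape tuple rather than an integer count.
sampler = normal()
base_samples = sampler(key, (128,))  # shape (128, 1)

# Quasi-MC Halton sampler over a hypothetical 2-dimensional unit cube.
qmc_sampler = halton_uniform(Uniform(jnp.zeros((2,)), jnp.ones((2,))))
qmc_samples = qmc_sampler(key, (128,))  # shape (128, 2)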
6 changes: 3 additions & 3 deletions boax/core/samplers/base.py
@@ -14,7 +14,7 @@

 """Base interface for samplers."""

-from typing import Protocol
+from typing import Protocol, Sequence

 from boax.utils.typing import Array, PRNGKey

@@ -27,13 +27,13 @@ class Sampler(Protocol):
   and returns `num_results` samples.
   """

-  def __call__(self, key: PRNGKey, num_results: int) -> Array:
+  def __call__(self, key: PRNGKey, shape: Sequence[int]) -> Array:
     """
     Draws `num_results` of samples.

     Args:
       key: The pseudo-random number generator key.
-      candidates: The number of results to return.
+      shape: The sample shape.

     Returns:
       A set of `num_results` samples.
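Since Sampler is a Protocol, any callable with this signature conforms to the revised interface. A toy conforming sampler, purely illustrative and not part of the PR:

from typing import Sequence

from jax import random

from boax.utils.typing import Array, PRNGKey


def rademacher_sampler(key: PRNGKey, shape: Sequence[int]) -> Array:
  # Draws +/-1 base samples with the requested sample shape.
  return random.rademacher(key, tuple(shape), dtype=float)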
1 change: 1 addition & 0 deletions boax/core/samplers/functions/__init__.py
@@ -14,5 +14,6 @@

 """The sampler functions sub-package."""

+from . import iid as iid
 from . import quasi_random as quasi_random
 from . import utils as utils
29 changes: 29 additions & 0 deletions boax/core/samplers/functions/iid.py
@@ -0,0 +1,29 @@
+# Copyright 2023 The Boax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""IID sampling functions."""
+
+from typing import Sequence
+
+from jax import random
+
+from boax.utils.typing import Array, PRNGKey
+
+
+def uniform(key: PRNGKey, sample_shape: Sequence[int], ndims: int) -> Array:
+  return random.uniform(key, sample_shape + (ndims,))
+
+
+def normal(key: PRNGKey, sample_shape: Sequence[int], ndims: int) -> Array:
+  return random.normal(key, sample_shape + (ndims,))
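These helpers simply append the event dimension to the requested sample shape. A quick sketch of the resulting shapes (illustrative; note that sample_shape is concatenated with a tuple, so a tuple rather than a list is expected):

from jax import random

from boax.core.samplers.functions import iid

key = random.PRNGKey(0)
draws = iid.normal(key, (4, 8), ndims=3)  # shape (4, 8, 3)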
12 changes: 9 additions & 3 deletions boax/core/samplers/functions/quasi_random.py
@@ -14,6 +14,8 @@

 """Quasi Random sampling functions."""

+from typing import Sequence
+
 from jax import numpy as jnp
 from jax import random

@@ -26,8 +28,11 @@
 assert len(PRIMES) == MAX_DIMENSION


-def halton_sequence(key: PRNGKey, num_samples: int, ndims: int) -> Array:
+def halton_sequence(
+  key: PRNGKey, sample_shape: Sequence[int], ndims: int
+) -> Array:
   shuffle_key, correction_key = random.split(key)
+  num_samples = jnp.prod(jnp.asarray(sample_shape))

   radixes = PRIMES[0:ndims][..., jnp.newaxis]
   indices = jnp.reshape(jnp.arange(num_samples) + 1, (-1, 1, 1))

@@ -48,8 +53,9 @@ def halton_sequence(
   base_values = jnp.sum(shuffled / (radixes * weights), axis=-1)
   zero_correction = random.uniform(correction_key, (ndims, 1))

-  return (
-    base_values + (zero_correction / (radixes**max_sizes_by_axes)).flatten()
+  return jnp.reshape(
+    base_values + (zero_correction / (radixes**max_sizes_by_axes)).flatten(),
+    sample_shape + (ndims,),
   )
Expand Down
3 changes: 3 additions & 0 deletions boax/optimization/acquisitions/__init__.py
@@ -25,6 +25,9 @@
 from .alias import probability_of_improvement as probability_of_improvement
 from .alias import q_expected_improvement as q_expected_improvement
 from .alias import q_knowledge_gradient as q_knowledge_gradient
+from .alias import (
+  q_multi_fidelity_knowledge_gradient as q_multi_fidelity_knowledge_gradient,
+)
 from .alias import q_probability_of_improvement as q_probability_of_improvement
 from .alias import q_upper_confidence_bound as q_upper_confidence_bound
 from .alias import upper_confidence_bound as upper_confidence_bound
70 changes: 53 additions & 17 deletions boax/optimization/acquisitions/alias.py
@@ -16,7 +16,8 @@

 import math
 from functools import partial
-from operator import attrgetter
+from operator import attrgetter, itemgetter
+from typing import Callable, Tuple

 from jax import jit, lax, scipy
 from jax import numpy as jnp

@@ -25,7 +26,7 @@
 from boax.core.distributions.normal import Normal
 from boax.optimization.acquisitions import functions
 from boax.optimization.acquisitions.base import Acquisition
-from boax.utils.functools import compose
+from boax.utils.functools import apply, compose, unwrap
 from boax.utils.typing import Array, Numeric
@@ -41,7 +42,7 @@ def probability_of_improvement(

   Example:
     >>> acqf = probability_of_improvement(0.2)
-    >>> poi = acqf(model(xs))
+    >>> poi = acqf(vmap(model)(xs))

   Args:
     best: The best function value observed so far.

@@ -71,7 +72,7 @@ def log_probability_of_improvement(

   Example:
     >>> acqf = log_probability_of_improvement(0.2)
-    >>> log_poi = acqf(model(xs))
+    >>> log_poi = acqf(vmap(model)(xs))

   Args:
     best: The best function value observed so far.

@@ -103,7 +104,7 @@ def expected_improvement(

   Example:
     >>> acqf = expected_improvement(0.2)
-    >>> ei = acqf(model(xs))
+    >>> ei = acqf(vmap(model)(xs))

   Args:
     best: The best function value observed so far.

@@ -138,7 +139,7 @@ def log_expected_improvement(

   Example:
     >>> acqf = log_expected_improvement(0.2)
-    >>> log_ei = acqf(model(xs))
+    >>> log_ei = acqf(vmap(model)(xs))

   Args:
     best: The best function value observed so far.

@@ -168,7 +169,7 @@ def upper_confidence_bound(

   Example:
     >>> acqf = upper_confidence_bound(2.0)
-    >>> ucb = acqf(model(xs))
+    >>> ucb = acqf(vmap(model)(xs))

   Args:
     beta: The mean and covariance trade-off parameter.

@@ -191,7 +192,7 @@ def posterior_mean() -> Acquisition:

   Example:
     >>> acqf = posterior_mean()
-    >>> mean = acqf(model(xs))
+    >>> mean = acqf(vmap(model)(xs))

   Args:
     model: A gaussian process regression surrogate model.

@@ -214,7 +215,7 @@ def posterior_scale() -> Acquisition:

   Example:
     >>> acqf = posterior_scale()
-    >>> scale = acqf(model(xs))
+    >>> scale = acqf(vmap(model)(xs))

   Args:
     model: A gaussian process regression surrogate model.

@@ -248,7 +249,7 @@ def q_probability_of_improvement(

   Example:
     >>> acqf = q_probability_of_improvement(1.0, 0.2)
-    >>> qpoi = acqf(model(xs))
+    >>> qpoi = acqf(vmap(model)(xs))

   Args:
     tau: The temperature parameter.

@@ -260,7 +261,7 @@

   return jit(
     compose(
-      partial(jnp.mean, axis=-1),
+      partial(jnp.mean, axis=0),
       partial(jnp.amax, axis=-1),
       partial(functions.monte_carlo.qpoi, best=best, tau=tau),
     )

@@ -281,7 +282,7 @@ def q_expected_improvement(

   Example:
     >>> acqf = q_expected_improvement(0.2)
-    >>> qei = acqf(model(xs))
+    >>> qei = acqf(vmap(model)(xs))

   Args:
     best: The best function value observed so far.

@@ -292,7 +293,7 @@

   return jit(
     compose(
-      partial(jnp.mean, axis=-1),
+      partial(jnp.mean, axis=0),
       partial(jnp.amax, axis=-1),
       partial(functions.monte_carlo.qei, best=best),
     )

@@ -311,7 +312,7 @@ def q_upper_confidence_bound(

   Example:
     >>> acqf = q_upper_confidence_bound(2.0)
-    >>> qucb = acqf(model(xs))
+    >>> qucb = acqf(vmap(model)(xs))

   Args:
     beta: The mean and covariance trade-off parameter.

@@ -324,7 +325,7 @@

   return jit(
     compose(
-      partial(jnp.mean, axis=-1),
+      partial(jnp.mean, axis=0),
       partial(jnp.amax, axis=-1),
       partial(functions.monte_carlo.qucb, beta=beta_prime),
     )

@@ -339,7 +340,7 @@ def q_knowledge_gradient(

   Example:
     >>> acqf = q_knowledge_gradient(0.2)
-    >>> qucb = acqf(model(xs))
+    >>> qkg = acqf(vmap(model)(xs))

   Args:
     best: The best function value observed so far.

@@ -350,9 +351,44 @@

   return jit(
     compose(
-      partial(jnp.mean, axis=-1),
+      partial(jnp.mean, axis=0),
       partial(jnp.squeeze, axis=-1),
       partial(lax.sub, y=best),
       attrgetter('loc'),
     )
   )
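The recurring axis=-1 to axis=0 switch in these hunks reflects a sample-shape-leading layout: MC utilities come out as (num_samples, q), the q-batch is reduced on the last axis, and the MC average is now taken over the leading sample axis. Sketched with plain jnp ops (shapes are illustrative, not from the PR):

from jax import numpy as jnp

num_mc_samples, q = 128, 4
utilities = jnp.zeros((num_mc_samples, q))  # per-sample, per-candidate values

best_per_sample = jnp.amax(utilities, axis=-1)  # shape (128,)
q_value = jnp.mean(best_per_sample, axis=0)     # scalar MC estimate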


+def q_multi_fidelity_knowledge_gradient(
+  best: Numeric,
+  cost_fn: Callable,
+) -> Acquisition[Tuple[Normal, Array]]:
+  """
+  MC-based batch multi-fidelity Knowledge Gradient acquisition function.
+
+  Example:
+    >>> acqf = q_multi_fidelity_knowledge_gradient(0.2, cost_fn)
+    >>> qmfkg = acqf(vmap(model)(xs))
+
+  Args:
+    best: The best function value observed so far.
+    cost_fn: The callable that combines the knowledge gradient values
+      with the evaluation costs.
+
+  Returns:
+    The corresponding `Acquisition`.
+  """
+
+  return jit(
+    compose(
+      partial(jnp.mean, axis=0),
+      apply(
+        unwrap(cost_fn),
+        compose(
+          partial(jnp.squeeze, axis=-1),
+          partial(lax.sub, y=best),
+          attrgetter('loc'),
+          itemgetter(0),
+        ),
+        itemgetter(1),
+      ),
+    )
+  )
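A hedged usage sketch (not part of the diff): judging from the composition, the acquisition consumes (Normal, cost) pairs via itemgetter, and cost_fn receives the knowledge-gradient values together with the per-sample costs. The cost_fn below (value per unit cost) and the fantasy posterior shapes are assumptions for illustration, not the PR's prescribed usage.

from jax import numpy as jnp

from boax.core.distributions.normal import Normal
from boax.optimization.acquisitions import q_multi_fidelity_knowledge_gradient


def cost_fn(values, costs):
  # Hypothetical cost-aware adjustment: knowledge gradient per unit cost.
  return values / costs


acqf = q_multi_fidelity_knowledge_gradient(0.2, cost_fn)

# 128 MC fantasy samples with a trailing singleton axis, plus per-sample costs.
posterior = Normal(jnp.zeros((128, 1)), jnp.ones((128, 1)))
costs = jnp.ones((128,))
qmfkg = acqf((posterior, costs))  # scalar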