Skip to content

Commit

Permalink
Allow AskTellOptimizer that doesn't track datasets (secondmind-labs#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored May 6, 2024
1 parent f5a83ec commit aabd293
Show file tree
Hide file tree
Showing 9 changed files with 487 additions and 53 deletions.
39 changes: 31 additions & 8 deletions tests/integration/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
import pickle
import tempfile
from dataclasses import replace
from typing import Callable, Mapping, Tuple, Union

import numpy.testing as npt
Expand All @@ -34,8 +35,8 @@
TREGOBox,
)
from trieste.acquisition.utils import copy_to_local_models
from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.bayesian_optimizer import OptimizationResult, Record
from trieste.ask_tell_optimization import AskTellOptimizer, AskTellOptimizerState
from trieste.bayesian_optimizer import OptimizationResult
from trieste.data import Dataset
from trieste.logging import set_step_number, tensorboard_writer
from trieste.models import TrainableProbabilisticModel
Expand Down Expand Up @@ -150,16 +151,22 @@ def test_ask_tell_optimizer_finds_minima_of_the_scaled_branin_function(


@random_seed
@pytest.mark.parametrize("track_data", [True, False])
@pytest.mark.parametrize(*copy.deepcopy(OPTIMIZER_PARAMS))
def test_ask_tell_optimizer_finds_minima_of_simple_quadratic(
track_data: bool,
num_steps: int,
reload_state: bool,
acquisition_rule_fn: AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int],
) -> None:
# for speed reasons we sometimes test with a simple quadratic defined on the same search space
# as branin; currently assume that every rule should be able to solve this in 5 steps
_test_ask_tell_optimization_finds_minima(
False, min(num_steps, 5), reload_state, acquisition_rule_fn
False,
min(num_steps, 5),
reload_state,
acquisition_rule_fn,
track_data=track_data,
)


Expand All @@ -168,6 +175,7 @@ def _test_ask_tell_optimization_finds_minima(
num_steps: int,
reload_state: bool,
acquisition_rule_fn: AcquisitionRuleFunction | Tuple[AcquisitionRuleFunction, int],
track_data: bool = True,
) -> None:
# For the case when optimization state is saved and reloaded on each iteration
# we need to use a new acquisition function object to imitate real life usage
Expand Down Expand Up @@ -195,7 +203,7 @@ def _test_ask_tell_optimization_finds_minima(
with tensorboard_writer(summary_writer):
set_step_number(0)
ask_tell = AskTellOptimizer(
search_space, initial_dataset, models, acquisition_rule_fn()
search_space, initial_dataset, models, acquisition_rule_fn(), track_data=track_data
)

for i in range(1, num_steps + 1):
Expand All @@ -206,10 +214,10 @@ def _test_ask_tell_optimization_finds_minima(
new_point = ask_tell.ask()

if reload_state:
state: Record[
state: AskTellOptimizerState[
None | State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State],
GaussianProcessRegression,
] = ask_tell.to_record()
] = ask_tell.to_state()
written_state = pickle.dumps(state)

# If query points are rank 3, then use a batched observer.
Expand All @@ -222,11 +230,26 @@ def _test_ask_tell_optimization_finds_minima(

if reload_state:
state = pickle.loads(written_state)
state_record = state.record
if not track_data:
# reload using the up-to-date dataset
state_record = replace(state_record, datasets=initial_dataset)
ask_tell = AskTellOptimizer.from_record(
state, search_space, acquisition_rule_fn()
state_record,
search_space,
acquisition_rule_fn(),
track_data=track_data,
local_data_ixs=state.local_data_ixs,
)

ask_tell.tell(new_data_point)
if track_data:
ask_tell.tell(new_data_point)
else:
if isinstance(new_data_point, Dataset):
new_data_point = {OBJECTIVE: new_data_point}
for tag in initial_dataset.keys():
initial_dataset[tag] += new_data_point[tag]
ask_tell.tell(initial_dataset)

result: OptimizationResult[
None | State[TensorType, AsynchronousRuleState | BatchTrustRegionBox.State],
Expand Down
45 changes: 43 additions & 2 deletions tests/unit/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Mapping, Optional
from typing import Any, Mapping, Optional, Sequence
from unittest.mock import MagicMock

import numpy as np
Expand All @@ -31,7 +31,7 @@
)
from trieste.data import Dataset
from trieste.space import Box, SearchSpaceType
from trieste.types import Tag
from trieste.types import Tag, TensorType
from trieste.utils.misc import LocalizedTag


Expand Down Expand Up @@ -166,6 +166,47 @@ def test_with_local_datasets(
assert datasets[ltag] is original_datasets[global_tag]


@pytest.mark.parametrize(
    "datasets, indices",
    [
        (
            {
                "a": Dataset(tf.constant([[1.0, 2.0], [3.0, 4.0]]), tf.constant([[5.0], [6.0]])),
                "b": Dataset(tf.constant([[7.0, 8.0], [9.0, 1.0]]), tf.constant([[2.0], [3.0]])),
            },
            [tf.constant([0]), tf.constant([0, 1])],
        ),
        (
            {
                "a": Dataset(tf.constant([[1.0, 2.0], [3.0, 4.0]]), tf.constant([[5.0], [6.0]])),
                "b": Dataset(tf.constant([[7.0, 8.0], [9.0, 1.0]]), tf.constant([[2.0], [3.0]])),
            },
            [tf.constant([], dtype=tf.int32), tf.constant([0])],
        ),
    ],
)
def test_with_local_datasets_indices(
    datasets: Mapping[Tag, Dataset], indices: Sequence[TensorType]
) -> None:
    """Check ``with_local_datasets`` when explicit per-local-dataset indices are supplied.

    The result must hold every global dataset plus one local dataset per global tag
    and index entry. Global datasets (and any local datasets already present in the
    input) must be passed through by identity; each newly created local dataset must
    contain as many query points and observations as its index tensor has entries.
    """
    # Snapshot the input mapping so identity checks compare against the originals.
    originals = dict(datasets)
    global_tags = {tag for tag in originals if not LocalizedTag.from_tag(tag).is_local}
    local_count = len(indices)

    localized = with_local_datasets(datasets, local_count, indices)
    assert len(localized) == len(global_tags) * (1 + local_count)

    for global_tag in global_tags:
        assert localized[global_tag] is originals[global_tag]
        for position, local_indices in enumerate(indices):
            local_tag = LocalizedTag(global_tag, position)
            if local_tag in originals:
                # Pre-existing local datasets must not be replaced.
                assert localized[local_tag] is originals[local_tag]
                continue
            assert len(localized[local_tag].query_points) == len(local_indices)
            assert len(localized[local_tag].observations) == len(local_indices)


@pytest.mark.parametrize(
"points, tolerance, expected_mask",
[
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/models/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _model_stack() -> (

def test_model_stack_predict() -> None:
stack, (model01, model2, model3) = _model_stack()
assert all(
isinstance(model, TrainableProbabilisticModel) for model in (stack, model01, model2, model3)
)
query_points = tf.random.uniform([5, 7, 3])
mean, var = stack.predict(query_points)

Expand Down
Loading

0 comments on commit aabd293

Please sign in to comment.