optimize sequence of acquisition functions #2931

Closed · wants to merge 1 commit
75 changes: 57 additions & 18 deletions botorch/optim/optimize.py
@@ -65,7 +65,7 @@ class OptimizeAcqfInputs:
     See docstring for `optimize_acqf` for explanation of parameters.
     """

-    acq_function: AcquisitionFunction
+    acq_function: AcquisitionFunction | None
     bounds: Tensor
     q: int
     num_restarts: int
@@ -85,6 +85,7 @@ class OptimizeAcqfInputs:
     return_full_tree: bool = False
     retry_on_optimization_warning: bool = True
     ic_gen_kwargs: dict = dataclasses.field(default_factory=dict)
+    acq_function_sequence: list[AcquisitionFunction] | None = None

     @property
     def full_tree(self) -> bool:
@@ -93,6 +94,10 @@ def full_tree(self) -> bool:
         )

     def __post_init__(self) -> None:
+        if self.acq_function is None and self.acq_function_sequence is None:
+            raise ValueError(
+                "Either `acq_function` or `acq_function_sequence` must be specified."
+            )
         if self.inequality_constraints is None and not (
             self.bounds.ndim == 2 and self.bounds.shape[0] == 2
         ):
@@ -168,6 +173,16 @@ def __post_init__(self) -> None:
         ):
             raise ValueError("All indices (keys) in `fixed_features` must be >= 0.")

+        if self.acq_function_sequence is not None:
+            if not self.sequential:
+                raise ValueError(
+                    "acq_function_sequence requires sequential optimization."
+                )
+            if len(self.acq_function_sequence) != self.q:
+                raise ValueError("acq_function_sequence must have length q.")
+            if self.q < 2:
+                raise ValueError("acq_function_sequence requires q > 1.")
+
     def get_ic_generator(self) -> TGenInitialConditions:
         if self.ic_generator is not None:
             return self.ic_generator
@@ -264,29 +279,47 @@ def _optimize_acqf_sequential_q(
         else None
     )
     candidate_list, acq_value_list = [], []
-    base_X_pending = opt_inputs.acq_function.X_pending
+    if opt_inputs.acq_function_sequence is None:
+        acq_function_sequence = [opt_inputs.acq_function]
+    else:
+        acq_function_sequence = opt_inputs.acq_function_sequence
+    base_X_pending = [acqf.X_pending for acqf in acq_function_sequence]
+    n_acq = len(acq_function_sequence)

+    new_kwargs = {
+        "q": 1,
+        "batch_initial_conditions": None,
+        "return_best_only": True,
+        "sequential": False,
+        "timeout_sec": timeout_sec,
+        "acq_function_sequence": None,
+    }
+    new_inputs = dataclasses.replace(opt_inputs, **new_kwargs)

-    new_inputs = dataclasses.replace(
-        opt_inputs,
-        q=1,
-        batch_initial_conditions=None,
-        return_best_only=True,
-        sequential=False,
-        timeout_sec=timeout_sec,
-    )
     for i in range(opt_inputs.q):
+        if n_acq > 1:
+            acq_function = acq_function_sequence[i]
+            new_kwargs["acq_function"] = acq_function
+            new_inputs = dataclasses.replace(opt_inputs, **new_kwargs)
+        if len(candidate_list) > 0:
+            candidates = torch.cat(candidate_list, dim=-2)
+            new_inputs.acq_function.set_X_pending(
+                torch.cat([base_X_pending[i % n_acq], candidates], dim=-2)
+                if base_X_pending[i % n_acq] is not None
+                else candidates
+            )
         candidate, acq_value = _optimize_acqf_batch(new_inputs)

         candidate_list.append(candidate)
         acq_value_list.append(acq_value)
-        candidates = torch.cat(candidate_list, dim=-2)
-        new_inputs.acq_function.set_X_pending(
-            torch.cat([base_X_pending, candidates], dim=-2)
-            if base_X_pending is not None
-            else candidates
-        )

         logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
-    opt_inputs.acq_function.set_X_pending(base_X_pending)
     model_name = type(new_inputs.acq_function.model).__name__
     logger.debug(f"Used model {model_name} for candidate generation.")
+    candidates = torch.cat(candidate_list, dim=-2)
+    # Re-set X_pendings on the acquisitions to base values
+    for acqf, X_pending in zip(acq_function_sequence, base_X_pending):
+        acqf.set_X_pending(X_pending)
     return candidates, torch.stack(acq_value_list)

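For illustration, the pending-point bookkeeping in the hunk above can be reduced to the following standalone sketch (not part of the diff): one acquisition function per candidate (indexed with `i % n_acq`, so a single acquisition also works), each one seeing its own base `X_pending` plus all candidates generated so far, and every acquisition restored to its original `X_pending` at the end. `ToyAcqf` is a hypothetical stand-in for an acquisition function, modeling only the `X_pending` handling.

import torch


class ToyAcqf:
    def __init__(self, X_pending=None):
        self.X_pending = X_pending

    def set_X_pending(self, X_pending=None):
        self.X_pending = X_pending


q, d = 3, 2
acqfs = [ToyAcqf(), ToyAcqf(torch.ones(2, d)), ToyAcqf()]
base_X_pending = [a.X_pending for a in acqfs]
n_acq = len(acqfs)

candidate_list = []
for i in range(q):
    acqf = acqfs[i % n_acq]
    if candidate_list:
        # Each acquisition sees its own base pending points plus all prior candidates.
        candidates = torch.cat(candidate_list, dim=-2)
        acqf.set_X_pending(
            torch.cat([base_X_pending[i % n_acq], candidates], dim=-2)
            if base_X_pending[i % n_acq] is not None
            else candidates
        )
    # Stand-in for the inner q=1 optimization (_optimize_acqf_batch in the hunk).
    candidate_list.append(torch.rand(1, d))

# Restore the original X_pending on every acquisition function, as in the
# last added lines of the hunk.
for acqf, X_pending in zip(acqfs, base_X_pending):
    acqf.set_X_pending(X_pending)
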
@@ -517,7 +550,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:


 def optimize_acqf(
-    acq_function: AcquisitionFunction,
+    acq_function: AcquisitionFunction | None,
     bounds: Tensor,
     q: int,
     num_restarts: int,
@@ -532,6 +565,7 @@
     return_best_only: bool = True,
     gen_candidates: TGenCandidates | None = None,
     sequential: bool = False,
+    acq_function_sequence: list[AcquisitionFunction] | None = None,
     *,
     ic_generator: TGenInitialConditions | None = None,
     timeout_sec: float | None = None,
@@ -627,6 +661,10 @@ def optimize_acqf(
             inputs. Default: `gen_candidates_scipy`
         sequential: If False, uses joint optimization, otherwise uses sequential
             optimization for optimizing multiple joint candidates (q > 1).
+        acq_function_sequence: A list of acquisition functions to be optimized
+            sequentially. Must have length `q` (with `q > 1`) and requires
+            `sequential=True`. Used for ensembling candidates from different
+            acquisition functions. If omitted, all `q` candidates use `acq_function`.
         ic_generator: Function for generating initial conditions. Not needed when
             `batch_initial_conditions` are provided. Defaults to
             `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
@@ -689,6 +727,7 @@
         return_full_tree=return_full_tree,
         retry_on_optimization_warning=retry_on_optimization_warning,
         ic_gen_kwargs=ic_gen_kwargs,
+        acq_function_sequence=acq_function_sequence,
     )
     return _optimize_acqf(opt_inputs=opt_acqf_inputs)

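For illustration, a usage sketch of the `acq_function_sequence` argument this diff adds (not part of the diff, and only valid with the PR applied). The surrogate model and the particular MC acquisition functions below are illustrative assumptions; any BoTorch acquisition functions that support `set_X_pending` could be used.

import torch

from botorch.acquisition import (
    qExpectedImprovement,
    qProbabilityOfImprovement,
    qUpperConfidenceBound,
)
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data on the unit square and a simple GP surrogate.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

best_f = train_Y.max()
# One acquisition function per candidate; the list length must equal q.
acqf_sequence = [
    qExpectedImprovement(model=model, best_f=best_f),
    qProbabilityOfImprovement(model=model, best_f=best_f),
    qUpperConfidenceBound(model=model, beta=0.2),
]

candidates, acq_values = optimize_acqf(
    acq_function=None,  # optional with this PR when a sequence is supplied
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=3,
    num_restarts=4,
    raw_samples=32,
    sequential=True,  # required by acq_function_sequence
    acq_function_sequence=acqf_sequence,
)
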
3 changes: 3 additions & 0 deletions botorch/utils/testing.py
@@ -674,11 +674,14 @@ def __init__(self): # noqa: D107
         """
         self.model = None
         self.X_pending = None
+        self._call_args = {"__call__": [], "set_X_pending": []}

     def __call__(self, X):
+        self._call_args["__call__"].append(X)
         return X[..., 0].max(dim=-1).values

     def set_X_pending(self, X_pending: Tensor | None = None):
+        self._call_args["set_X_pending"].append(X_pending)
         self.X_pending = X_pending

91 changes: 91 additions & 0 deletions test/optim/test_optimize.py
@@ -449,6 +449,97 @@ def test_optimize_acqf_sequential(
                 sequential=True,
             )

+    @mock.patch(
+        "botorch.optim.optimize.gen_candidates_scipy", wraps=gen_candidates_scipy
+    )
+    def test_optimize_acq_function_sequence(
+        self,
+        mock_gen_candidates_scipy,
+    ):
+        acq_function_sequence = [MockAcquisitionFunction() for _ in range(3)]
+        bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
+        # Validation
+        with self.assertRaisesRegex(
+            ValueError,
+            "Either `acq_function` or `acq_function_sequence` must be specified",
+        ):
+            optimize_acqf(
+                acq_function=None,
+                bounds=bounds,
+                q=3,
+                num_restarts=2,
+                raw_samples=10,
+                sequential=True,
+                acq_function_sequence=None,
+            )
+        with self.assertRaisesRegex(
+            ValueError,
+            "acq_function_sequence requires sequential optimization",
+        ):
+            optimize_acqf(
+                acq_function=mock.MagicMock(),
+                bounds=bounds,
+                q=3,
+                num_restarts=2,
+                raw_samples=10,
+                sequential=False,
+                acq_function_sequence=acq_function_sequence,
+            )
+        with self.assertRaisesRegex(
+            ValueError,
+            "acq_function_sequence must have length q",
+        ):
+            optimize_acqf(
+                acq_function=mock.MagicMock(),
+                bounds=bounds,
+                q=2,
+                num_restarts=2,
+                raw_samples=10,
+                sequential=True,
+                acq_function_sequence=acq_function_sequence,
+            )
+        with self.assertRaisesRegex(
+            ValueError,
+            "acq_function_sequence requires q > 1",
+        ):
+            optimize_acqf(
+                acq_function=mock.MagicMock(),
+                bounds=bounds,
+                q=1,
+                num_restarts=2,
+                raw_samples=10,
+                sequential=True,
+                acq_function_sequence=acq_function_sequence[:1],
+            )
+        # Test that uses sequence of acquisitions
+        acq_function = mock.MagicMock()
+        acq_function.X_pending = None
+        acq_function_sequence[2].X_pending = torch.ones(2, 3)
+        _ = optimize_acqf(
+            acq_function=acq_function,
+            bounds=bounds,
+            q=3,
+            num_restarts=2,
+            raw_samples=10,
+            sequential=True,
+            acq_function_sequence=acq_function_sequence,
+        )
+        self.assertEqual(mock_gen_candidates_scipy.call_count, 3)
+        self.assertEqual(acq_function_sequence[0]._call_args["set_X_pending"], [None])
+        for i in range(1, 2):
+            set_X_args = acq_function_sequence[i]._call_args["set_X_pending"]
+            self.assertEqual(len(set_X_args), 2)
+            self.assertEqual(len(set_X_args[0]), i)
+            self.assertIsNone(set_X_args[1])
+        set_X_args = acq_function_sequence[2]._call_args["set_X_pending"]
+        self.assertEqual(len(set_X_args), 2)
+        self.assertEqual(len(set_X_args[0]), 4)
+        self.assertTrue(
+            torch.equal(set_X_args[0][:2, :], torch.ones(2, 3))
+        )  # base X_pending
+        self.assertTrue(torch.equal(set_X_args[1], torch.ones(2, 3)))  # reset
+        acq_function.assert_not_called()
+
     @mock.patch(
         "botorch.generation.gen.minimize_with_timeout",
         wraps=minimize_with_timeout,