Skip to content

Commit 20f1116

Browse files
blethamfacebook-github-bot
authored andcommitted
optimize sequence of acquisition functions (#2931)
Summary: Pull Request resolved: #2931 Enables sequential q-batch optimization using a sequence of acquisition functions rather than the same acquisition function for each point in the batch. Right now just for when using optimize_acqf, thus focused on continuous search spaces. Reviewed By: sdaulton Differential Revision: D78560867 fbshipit-source-id: cbc0033ef904db63efee0f303c03f7901d49db69
1 parent 22bebf3 commit 20f1116

File tree

3 files changed

+151
-18
lines changed

3 files changed

+151
-18
lines changed

botorch/optim/optimize.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class OptimizeAcqfInputs:
6565
See docstring for `optimize_acqf` for explanation of parameters.
6666
"""
6767

68-
acq_function: AcquisitionFunction
68+
acq_function: AcquisitionFunction | None
6969
bounds: Tensor
7070
q: int
7171
num_restarts: int
@@ -85,6 +85,7 @@ class OptimizeAcqfInputs:
8585
return_full_tree: bool = False
8686
retry_on_optimization_warning: bool = True
8787
ic_gen_kwargs: dict = dataclasses.field(default_factory=dict)
88+
acq_function_sequence: list[AcquisitionFunction] | None = None
8889

8990
@property
9091
def full_tree(self) -> bool:
@@ -93,6 +94,10 @@ def full_tree(self) -> bool:
9394
)
9495

9596
def __post_init__(self) -> None:
97+
if self.acq_function is None and self.acq_function_sequence is None:
98+
raise ValueError(
99+
"Either `acq_function` or `acq_function_sequence` must be specified."
100+
)
96101
if self.inequality_constraints is None and not (
97102
self.bounds.ndim == 2 and self.bounds.shape[0] == 2
98103
):
@@ -168,6 +173,16 @@ def __post_init__(self) -> None:
168173
):
169174
raise ValueError("All indices (keys) in `fixed_features` must be >= 0.")
170175

176+
if self.acq_function_sequence is not None:
177+
if not self.sequential:
178+
raise ValueError(
179+
"acq_function_sequence requires sequential optimization."
180+
)
181+
if len(self.acq_function_sequence) != self.q:
182+
raise ValueError("acq_function_sequence must have length q.")
183+
if self.q < 2:
184+
raise ValueError("acq_function_sequence requires q > 1.")
185+
171186
def get_ic_generator(self) -> TGenInitialConditions:
172187
if self.ic_generator is not None:
173188
return self.ic_generator
@@ -264,29 +279,47 @@ def _optimize_acqf_sequential_q(
264279
else None
265280
)
266281
candidate_list, acq_value_list = [], []
267-
base_X_pending = opt_inputs.acq_function.X_pending
282+
if opt_inputs.acq_function_sequence is None:
283+
acq_function_sequence = [opt_inputs.acq_function]
284+
else:
285+
acq_function_sequence = opt_inputs.acq_function_sequence
286+
base_X_pending = [acqf.X_pending for acqf in acq_function_sequence]
287+
n_acq = len(acq_function_sequence)
288+
289+
new_kwargs = {
290+
"q": 1,
291+
"batch_initial_conditions": None,
292+
"return_best_only": True,
293+
"sequential": False,
294+
"timeout_sec": timeout_sec,
295+
"acq_function_sequence": None,
296+
}
297+
new_inputs = dataclasses.replace(opt_inputs, **new_kwargs)
268298

269-
new_inputs = dataclasses.replace(
270-
opt_inputs,
271-
q=1,
272-
batch_initial_conditions=None,
273-
return_best_only=True,
274-
sequential=False,
275-
timeout_sec=timeout_sec,
276-
)
277299
for i in range(opt_inputs.q):
300+
if n_acq > 1:
301+
acq_function = acq_function_sequence[i]
302+
new_kwargs["acq_function"] = acq_function
303+
new_inputs = dataclasses.replace(opt_inputs, **new_kwargs)
304+
if len(candidate_list) > 0:
305+
candidates = torch.cat(candidate_list, dim=-2)
306+
new_inputs.acq_function.set_X_pending(
307+
torch.cat([base_X_pending[i % n_acq], candidates], dim=-2)
308+
if base_X_pending[i % n_acq] is not None
309+
else candidates
310+
)
278311
candidate, acq_value = _optimize_acqf_batch(new_inputs)
279312

280313
candidate_list.append(candidate)
281314
acq_value_list.append(acq_value)
282-
candidates = torch.cat(candidate_list, dim=-2)
283-
new_inputs.acq_function.set_X_pending(
284-
torch.cat([base_X_pending, candidates], dim=-2)
285-
if base_X_pending is not None
286-
else candidates
287-
)
315+
288316
logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
289-
opt_inputs.acq_function.set_X_pending(base_X_pending)
317+
model_name = type(new_inputs.acq_function.model).__name__
318+
logger.debug(f"Used model {model_name} for candidate generation.")
319+
candidates = torch.cat(candidate_list, dim=-2)
320+
# Re-set X_pendings on the acquisitions to base values
321+
for acqf, X_pending in zip(acq_function_sequence, base_X_pending):
322+
acqf.set_X_pending(X_pending)
290323
return candidates, torch.stack(acq_value_list)
291324

292325

@@ -517,7 +550,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
517550

518551

519552
def optimize_acqf(
520-
acq_function: AcquisitionFunction,
553+
acq_function: AcquisitionFunction | None,
521554
bounds: Tensor,
522555
q: int,
523556
num_restarts: int,
@@ -532,6 +565,7 @@ def optimize_acqf(
532565
return_best_only: bool = True,
533566
gen_candidates: TGenCandidates | None = None,
534567
sequential: bool = False,
568+
acq_function_sequence: list[AcquisitionFunction] | None = None,
535569
*,
536570
ic_generator: TGenInitialConditions | None = None,
537571
timeout_sec: float | None = None,
@@ -627,6 +661,10 @@ def optimize_acqf(
627661
inputs. Default: `gen_candidates_scipy`
628662
sequential: If False, uses joint optimization, otherwise uses sequential
629663
optimization for optimizing multiple joint candidates (q > 1).
664+
acq_function_sequence: A list of acquisition functions to be optimized
665+
sequentially. Must be of length q>1, and requires sequential=True. Used
666+
for ensembling candidates from different acquisition functions. If
667+
omitted, use `acq_function` to generate all `q` candidates.
630668
ic_generator: Function for generating initial conditions. Not needed when
631669
`batch_initial_conditions` are provided. Defaults to
632670
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
@@ -689,6 +727,7 @@ def optimize_acqf(
689727
return_full_tree=return_full_tree,
690728
retry_on_optimization_warning=retry_on_optimization_warning,
691729
ic_gen_kwargs=ic_gen_kwargs,
730+
acq_function_sequence=acq_function_sequence,
692731
)
693732
return _optimize_acqf(opt_inputs=opt_acqf_inputs)
694733

botorch/utils/testing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,11 +674,14 @@ def __init__(self): # noqa: D107
674674
"""
675675
self.model = None
676676
self.X_pending = None
677+
self._call_args = {"__call__": [], "set_X_pending": []}
677678

678679
def __call__(self, X):
680+
self._call_args["__call__"].append(X)
679681
return X[..., 0].max(dim=-1).values
680682

681683
def set_X_pending(self, X_pending: Tensor | None = None):
684+
self._call_args["set_X_pending"].append(X_pending)
682685
self.X_pending = X_pending
683686

684687

test/optim/test_optimize.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,97 @@ def test_optimize_acqf_sequential(
449449
sequential=True,
450450
)
451451

452+
@mock.patch(
453+
"botorch.optim.optimize.gen_candidates_scipy", wraps=gen_candidates_scipy
454+
)
455+
def test_optimize_acq_function_sequence(
456+
self,
457+
mock_gen_candidates_scipy,
458+
):
459+
acq_function_sequence = [MockAcquisitionFunction() for _ in range(3)]
460+
bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
461+
# Validation
462+
with self.assertRaisesRegex(
463+
ValueError,
464+
"Either `acq_function` or `acq_function_sequence` must be specified",
465+
):
466+
optimize_acqf(
467+
acq_function=None,
468+
bounds=bounds,
469+
q=3,
470+
num_restarts=2,
471+
raw_samples=10,
472+
sequential=True,
473+
acq_function_sequence=None,
474+
)
475+
with self.assertRaisesRegex(
476+
ValueError,
477+
"acq_function_sequence requires sequential optimization",
478+
):
479+
optimize_acqf(
480+
acq_function=mock.MagicMock(),
481+
bounds=bounds,
482+
q=3,
483+
num_restarts=2,
484+
raw_samples=10,
485+
sequential=False,
486+
acq_function_sequence=acq_function_sequence,
487+
)
488+
with self.assertRaisesRegex(
489+
ValueError,
490+
"acq_function_sequence must have length q",
491+
):
492+
optimize_acqf(
493+
acq_function=mock.MagicMock(),
494+
bounds=bounds,
495+
q=2,
496+
num_restarts=2,
497+
raw_samples=10,
498+
sequential=True,
499+
acq_function_sequence=acq_function_sequence,
500+
)
501+
with self.assertRaisesRegex(
502+
ValueError,
503+
"acq_function_sequence requires q > 1",
504+
):
505+
optimize_acqf(
506+
acq_function=mock.MagicMock(),
507+
bounds=bounds,
508+
q=1,
509+
num_restarts=2,
510+
raw_samples=10,
511+
sequential=True,
512+
acq_function_sequence=acq_function_sequence[:1],
513+
)
514+
# Test that uses sequence of acquisitions
515+
acq_function = mock.MagicMock()
516+
acq_function.X_pending = None
517+
acq_function_sequence[2].X_pending = torch.ones(2, 3)
518+
_ = optimize_acqf(
519+
acq_function=acq_function,
520+
bounds=bounds,
521+
q=3,
522+
num_restarts=2,
523+
raw_samples=10,
524+
sequential=True,
525+
acq_function_sequence=acq_function_sequence,
526+
)
527+
self.assertEqual(mock_gen_candidates_scipy.call_count, 3)
528+
self.assertEqual(acq_function_sequence[0]._call_args["set_X_pending"], [None])
529+
for i in range(1, 2):
530+
set_X_args = acq_function_sequence[i]._call_args["set_X_pending"]
531+
self.assertEqual(len(set_X_args), 2)
532+
self.assertEqual(len(set_X_args[0]), i)
533+
self.assertIsNone(set_X_args[1])
534+
set_X_args = acq_function_sequence[2]._call_args["set_X_pending"]
535+
self.assertEqual(len(set_X_args), 2)
536+
self.assertEqual(len(set_X_args[0]), 4)
537+
self.assertTrue(
538+
torch.equal(set_X_args[0][:2, :], torch.ones(2, 3))
539+
) # base X_pending
540+
self.assertTrue(torch.equal(set_X_args[1], torch.ones(2, 3))) # reset
541+
acq_function.assert_not_called()
542+
452543
@mock.patch(
453544
"botorch.generation.gen.minimize_with_timeout",
454545
wraps=minimize_with_timeout,

0 commit comments

Comments
 (0)