@@ -65,7 +65,7 @@ class OptimizeAcqfInputs:
65
65
See docstring for `optimize_acqf` for explanation of parameters.
66
66
"""
67
67
68
- acq_function : AcquisitionFunction
68
+ acq_function : AcquisitionFunction | None
69
69
bounds : Tensor
70
70
q : int
71
71
num_restarts : int
@@ -85,6 +85,7 @@ class OptimizeAcqfInputs:
85
85
return_full_tree : bool = False
86
86
retry_on_optimization_warning : bool = True
87
87
ic_gen_kwargs : dict = dataclasses .field (default_factory = dict )
88
+ acq_function_sequence : list [AcquisitionFunction ] | None = None
88
89
89
90
@property
90
91
def full_tree (self ) -> bool :
@@ -93,6 +94,10 @@ def full_tree(self) -> bool:
93
94
)
94
95
95
96
def __post_init__ (self ) -> None :
97
+ if self .acq_function is None and self .acq_function_sequence is None :
98
+ raise ValueError (
99
+ "Either `acq_function` or `acq_function_sequence` must be specified."
100
+ )
96
101
if self .inequality_constraints is None and not (
97
102
self .bounds .ndim == 2 and self .bounds .shape [0 ] == 2
98
103
):
@@ -168,6 +173,16 @@ def __post_init__(self) -> None:
168
173
):
169
174
raise ValueError ("All indices (keys) in `fixed_features` must be >= 0." )
170
175
176
+ if self .acq_function_sequence is not None :
177
+ if not self .sequential :
178
+ raise ValueError (
179
+ "acq_function_sequence requires sequential optimization."
180
+ )
181
+ if len (self .acq_function_sequence ) != self .q :
182
+ raise ValueError ("acq_function_sequence must have length q." )
183
+ if self .q < 2 :
184
+ raise ValueError ("acq_function_sequence requires q > 1." )
185
+
171
186
def get_ic_generator (self ) -> TGenInitialConditions :
172
187
if self .ic_generator is not None :
173
188
return self .ic_generator
@@ -264,29 +279,47 @@ def _optimize_acqf_sequential_q(
264
279
else None
265
280
)
266
281
candidate_list , acq_value_list = [], []
267
- base_X_pending = opt_inputs .acq_function .X_pending
282
+ if opt_inputs .acq_function_sequence is None :
283
+ acq_function_sequence = [opt_inputs .acq_function ]
284
+ else :
285
+ acq_function_sequence = opt_inputs .acq_function_sequence
286
+ base_X_pending = [acqf .X_pending for acqf in acq_function_sequence ]
287
+ n_acq = len (acq_function_sequence )
288
+
289
+ new_kwargs = {
290
+ "q" : 1 ,
291
+ "batch_initial_conditions" : None ,
292
+ "return_best_only" : True ,
293
+ "sequential" : False ,
294
+ "timeout_sec" : timeout_sec ,
295
+ "acq_function_sequence" : None ,
296
+ }
297
+ new_inputs = dataclasses .replace (opt_inputs , ** new_kwargs )
268
298
269
- new_inputs = dataclasses .replace (
270
- opt_inputs ,
271
- q = 1 ,
272
- batch_initial_conditions = None ,
273
- return_best_only = True ,
274
- sequential = False ,
275
- timeout_sec = timeout_sec ,
276
- )
277
299
for i in range (opt_inputs .q ):
300
+ if n_acq > 1 :
301
+ acq_function = acq_function_sequence [i ]
302
+ new_kwargs ["acq_function" ] = acq_function
303
+ new_inputs = dataclasses .replace (opt_inputs , ** new_kwargs )
304
+ if len (candidate_list ) > 0 :
305
+ candidates = torch .cat (candidate_list , dim = - 2 )
306
+ new_inputs .acq_function .set_X_pending (
307
+ torch .cat ([base_X_pending [i % n_acq ], candidates ], dim = - 2 )
308
+ if base_X_pending [i % n_acq ] is not None
309
+ else candidates
310
+ )
278
311
candidate , acq_value = _optimize_acqf_batch (new_inputs )
279
312
280
313
candidate_list .append (candidate )
281
314
acq_value_list .append (acq_value )
282
- candidates = torch .cat (candidate_list , dim = - 2 )
283
- new_inputs .acq_function .set_X_pending (
284
- torch .cat ([base_X_pending , candidates ], dim = - 2 )
285
- if base_X_pending is not None
286
- else candidates
287
- )
315
+
288
316
logger .info (f"Generated sequential candidate { i + 1 } of { opt_inputs .q } " )
289
- opt_inputs .acq_function .set_X_pending (base_X_pending )
317
+ model_name = type (new_inputs .acq_function .model ).__name__
318
+ logger .debug (f"Used model { model_name } for candidate generation." )
319
+ candidates = torch .cat (candidate_list , dim = - 2 )
320
+ # Re-set X_pendings on the acquisitions to base values
321
+ for acqf , X_pending in zip (acq_function_sequence , base_X_pending ):
322
+ acqf .set_X_pending (X_pending )
290
323
return candidates , torch .stack (acq_value_list )
291
324
292
325
@@ -517,7 +550,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
517
550
518
551
519
552
def optimize_acqf (
520
- acq_function : AcquisitionFunction ,
553
+ acq_function : AcquisitionFunction | None ,
521
554
bounds : Tensor ,
522
555
q : int ,
523
556
num_restarts : int ,
@@ -532,6 +565,7 @@ def optimize_acqf(
532
565
return_best_only : bool = True ,
533
566
gen_candidates : TGenCandidates | None = None ,
534
567
sequential : bool = False ,
568
+ acq_function_sequence : list [AcquisitionFunction ] | None = None ,
535
569
* ,
536
570
ic_generator : TGenInitialConditions | None = None ,
537
571
timeout_sec : float | None = None ,
@@ -627,6 +661,10 @@ def optimize_acqf(
627
661
inputs. Default: `gen_candidates_scipy`
628
662
sequential: If False, uses joint optimization, otherwise uses sequential
629
663
optimization for optimizing multiple joint candidates (q > 1).
664
+ acq_function_sequence: A list of acquisition functions to be optimized
665
+ sequentially. Must be of length q>1, and requires sequential=True. Used
666
+ for ensembling candidates from different acquisition functions. If
667
+ omitted, use `acq_function` to generate all `q` candidates.
630
668
ic_generator: Function for generating initial conditions. Not needed when
631
669
`batch_initial_conditions` are provided. Defaults to
632
670
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
@@ -689,6 +727,7 @@ def optimize_acqf(
689
727
return_full_tree = return_full_tree ,
690
728
retry_on_optimization_warning = retry_on_optimization_warning ,
691
729
ic_gen_kwargs = ic_gen_kwargs ,
730
+ acq_function_sequence = acq_function_sequence ,
692
731
)
693
732
return _optimize_acqf (opt_inputs = opt_acqf_inputs )
694
733
0 commit comments