Skip to content

Commit 09dd4dd

Browse files
sdaulton and facebook-github-bot
authored and committed
Add support for missing tasks in mtgp (#2960)
Summary: X-link: facebook/Ax#4121 Currently, cross-validation in Ax fails when using an MTGP if there are multiple metrics and only some metrics have been observed for some tasks. This is a modeling problem, since the model is a ModelListGP and not all MTGPs in the list are required to have the same tasks. Hence, when you pass in a test input, the model errors out if there are no observations from that task in the training data. This avoids the error by (optionally) mapping unexpected tasks to the "target task". This does not change the default behavior. For cross-validation in Ax, predictions are discarded if there are no observations for a given (task, metric) pair. This will still error out in Ax if data for the target trial is missing. Differential Revision: D79812024
1 parent 20f1116 commit 09dd4dd

File tree

8 files changed

+319
-109
lines changed

8 files changed

+319
-109
lines changed

botorch/models/fully_bayesian_multitask.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def __init__(
227227
outcome_transform: OutcomeTransform | None = None,
228228
input_transform: InputTransform | None = None,
229229
pyro_model: MultitaskSaasPyroModel | None = None,
230+
validate_task_values: bool = True,
230231
) -> None:
231232
r"""Initialize the fully Bayesian multi-task GP model.
232233
@@ -251,6 +252,9 @@ def __init__(
251252
in the model's forward pass.
252253
pyro_model: Optional `PyroModel` that has the same signature as
253254
`MultitaskSaasPyroModel`. Defaults to `MultitaskSaasPyroModel`.
255+
validate_task_values: If True, validate that the task values supplied in the
256+
input are expected tasks values. If false, unexpected task values
257+
will be mapped to the first output_task if supplied.
254258
"""
255259
if not (
256260
train_X.ndim == train_Y.ndim == 2
@@ -288,22 +292,19 @@ def __init__(
288292
# set on `self` below, it will be applied to the posterior in the
289293
# `posterior` method of `MultiTaskGP`.
290294
outcome_transform=None,
295+
all_tasks=all_tasks,
296+
validate_task_values=validate_task_values,
291297
)
292-
if all_tasks is not None and self._expected_task_values != set(all_tasks):
293-
raise NotImplementedError(
294-
"The `all_tasks` argument is not supported by SAAS MTGP. "
295-
f"The training data includes tasks {self._expected_task_values}, "
296-
f"got {all_tasks=}."
297-
)
298298
self.to(train_X)
299-
300299
self.mean_module = None
301300
self.covar_module = None
302301
self.likelihood = None
303302
if pyro_model is None:
304303
pyro_model = MultitaskSaasPyroModel()
304+
# apply task_mapper
305+
x_before, task_idcs, x_after = self._split_inputs(transformed_X)
305306
pyro_model.set_inputs(
306-
train_X=transformed_X,
307+
train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
307308
train_Y=train_Y,
308309
train_Yvar=train_Yvar,
309310
task_feature=task_feature,

botorch/models/gpytorch.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -779,39 +779,6 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
779779
"long-format" multi-task GP in the style of `MultiTaskGP`.
780780
"""
781781

782-
def _map_tasks(self, task_values: Tensor) -> Tensor:
783-
"""Map raw task values to the task indices used by the model.
784-
785-
Args:
786-
task_values: A tensor of task values.
787-
788-
Returns:
789-
A tensor of task indices with the same shape as the input
790-
tensor.
791-
"""
792-
if self._task_mapper is None:
793-
if not (
794-
torch.all(0 <= task_values) and torch.all(task_values < self.num_tasks)
795-
):
796-
raise ValueError(
797-
"Expected all task features in `X` to be between 0 and "
798-
f"self.num_tasks - 1. Got {task_values}."
799-
)
800-
else:
801-
task_values = task_values.long()
802-
803-
unexpected_task_values = set(task_values.unique().tolist()).difference(
804-
self._expected_task_values
805-
)
806-
if len(unexpected_task_values) > 0:
807-
raise ValueError(
808-
"Received invalid raw task values. Expected raw value to be in"
809-
f" {self._expected_task_values}, but got unexpected task values:"
810-
f" {unexpected_task_values}."
811-
)
812-
task_values = self._task_mapper[task_values]
813-
return task_values
814-
815782
def _apply_noise(
816783
self,
817784
X: Tensor,

botorch/models/multitask.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
all_tasks: list[int] | None = None,
116116
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
117117
input_transform: InputTransform | None = None,
118+
validate_task_values: bool = True,
118119
) -> None:
119120
r"""Multi-Task GP model using an ICM kernel.
120121
@@ -157,6 +158,9 @@ def __init__(
157158
instantiation of the model.
158159
input_transform: An input transform that is applied in the model's
159160
forward pass.
161+
validate_task_values: If True, validate that the task values supplied in the
162+
input are expected tasks values. If false, unexpected task values
163+
will be mapped to the first output_task if supplied.
160164
161165
Example:
162166
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -189,7 +193,7 @@ def __init__(
189193
"This is not allowed as it will lead to errors during model training."
190194
)
191195
all_tasks = all_tasks or all_tasks_inferred
192-
self.num_tasks = len(all_tasks)
196+
self.num_tasks = len(all_tasks_inferred)
193197
if outcome_transform == DEFAULT:
194198
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
195199
if outcome_transform is not None:
@@ -249,19 +253,61 @@ def __init__(
249253

250254
self.covar_module = data_covar_module * task_covar_module
251255
task_mapper = get_task_value_remapping(
252-
task_values=torch.tensor(
253-
all_tasks, dtype=torch.long, device=train_X.device
256+
observed_task_values=torch.tensor(
257+
all_tasks_inferred, dtype=torch.long, device=train_X.device
258+
),
259+
all_task_values=torch.tensor(
260+
sorted(all_tasks), dtype=torch.long, device=train_X.device
254261
),
255262
dtype=train_X.dtype,
263+
default_task_value=None if output_tasks is None else output_tasks[0],
256264
)
257265
self.register_buffer("_task_mapper", task_mapper)
258-
self._expected_task_values = set(all_tasks)
266+
self._expected_task_values = set(all_tasks_inferred)
259267
if input_transform is not None:
260268
self.input_transform = input_transform
261269
if outcome_transform is not None:
262270
self.outcome_transform = outcome_transform
271+
self._validate_task_values = validate_task_values
263272
self.to(train_X)
264273

274+
def _map_tasks(self, task_values: Tensor) -> Tensor:
275+
"""Map raw task values to the task indices used by the model.
276+
277+
Args:
278+
task_values: A tensor of task values.
279+
280+
Returns:
281+
A tensor of task indices with the same shape as the input
282+
tensor.
283+
"""
284+
long_task_values = task_values.long()
285+
if self._validate_task_values:
286+
if self._task_mapper is None:
287+
if not (
288+
torch.all(0 <= task_values)
289+
and torch.all(task_values < self.num_tasks)
290+
):
291+
raise ValueError(
292+
"Expected all task features in `X` to be between 0 and "
293+
f"self.num_tasks - 1. Got {task_values}."
294+
)
295+
else:
296+
unexpected_task_values = set(
297+
long_task_values.unique().tolist()
298+
).difference(self._expected_task_values)
299+
if len(unexpected_task_values) > 0:
300+
raise ValueError(
301+
"Received invalid raw task values. Expected raw value to be in"
302+
f" {self._expected_task_values}, but got unexpected task"
303+
f" values: {unexpected_task_values}."
304+
)
305+
task_values = self._task_mapper[long_task_values]
306+
elif self._task_mapper is not None:
307+
task_values = self._task_mapper[long_task_values]
308+
309+
return task_values
310+
265311
def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
266312
r"""Extracts features before task feature, task indices, and features after
267313
the task feature.
@@ -274,7 +320,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
274320
3-element tuple containing
275321
276322
- A `q x d` or `b x q x d` tensor with features before the task feature
277-
- A `q` or `b x q` tensor with mapped task indices
323+
- A `q` or `b x q x 1` tensor with mapped task indices
278324
- A `q x d` or `b x q x d` tensor with features after the task feature
279325
"""
280326
batch_shape = x.shape[:-2]
@@ -314,7 +360,7 @@ def get_all_tasks(
314360
raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
315361
task_feature = task_feature % (d + 1)
316362
all_tasks = (
317-
train_X[..., task_feature].unique(sorted=True).to(dtype=torch.long).tolist()
363+
train_X[..., task_feature].to(dtype=torch.long).unique(sorted=True).tolist()
318364
)
319365
return all_tasks, task_feature, d
320366

botorch/models/transforms/outcome.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,11 +511,13 @@ class StratifiedStandardize(Standardize):
511511

512512
def __init__(
513513
self,
514-
task_values: Tensor,
515514
stratification_idx: int,
515+
observed_task_values: Tensor,
516+
all_task_values: Tensor,
516517
batch_shape: torch.Size = torch.Size(), # noqa: B008
517518
min_stdv: float = 1e-8,
518-
# dtype: torch.dtype = torch.double,
519+
dtype: torch.dtype = torch.double,
520+
default_task_value: int | None = None,
519521
) -> None:
520522
r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
521523
@@ -528,13 +530,21 @@ def __init__(
528530
batch_shape: The batch_shape of the training targets.
529531
min_stddv: The minimum standard deviation for which to perform
530532
standardization (if lower, only de-mean the data).
533+
default_task_value: The default task value that unexpected tasks are
534+
mapped to. This is used in `get_task_value_remapping`.
535+
531536
"""
532537
OutcomeTransform.__init__(self)
533538
self._stratification_idx = stratification_idx
534-
task_values = task_values.unique(sorted=True)
535-
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
539+
observed_task_values = observed_task_values.unique(sorted=True)
540+
self.strata_mapping = get_task_value_remapping(
541+
observed_task_values=observed_task_values,
542+
all_task_values=all_task_values.unique(sorted=True),
543+
dtype=dtype,
544+
default_task_value=default_task_value,
545+
)
536546
if self.strata_mapping is None:
537-
self.strata_mapping = task_values
547+
self.strata_mapping = observed_task_values
538548
n_strata = self.strata_mapping.shape[0]
539549
self._min_stdv = min_stdv
540550
self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))

botorch/models/utils/assorted.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,20 @@ class fantasize(_Flag):
405405
_state: bool = False
406406

407407

408-
def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
409-
"""Construct an mapping of discrete task values to contiguous int-valued floats.
408+
def get_task_value_remapping(
409+
observed_task_values: Tensor,
410+
all_task_values: Tensor,
411+
dtype: torch.dtype,
412+
default_task_value: int | None,
413+
) -> Tensor | None:
414+
"""Construct an mapping of observed task values to contiguous int-valued floats.
410415
411416
Args:
412-
task_values: A sorted long-valued tensor of task values.
417+
observed_task_values: A sorted long-valued tensor of task values.
418+
all_task_values: A sorted long-valued tensor of task values.
413419
dtype: The dtype of the model inputs (e.g. `X`), which the new
414420
task values should have mapped to (e.g. float, double).
421+
default_task_value: The default task value to use for missing task values.
415422
416423
Returns:
417424
A tensor of shape `task_values.max() + 1` that maps task values
@@ -425,17 +432,31 @@ def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor
425432
if dtype not in (torch.float, torch.double):
426433
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
427434
task_range = torch.arange(
428-
len(task_values), dtype=task_values.dtype, device=task_values.device
435+
len(observed_task_values),
436+
dtype=all_task_values.dtype,
437+
device=all_task_values.device,
429438
)
430439
mapper = None
431-
if not torch.equal(task_values, task_range):
440+
441+
if default_task_value is None:
442+
fill_value = float("nan")
443+
else:
444+
mask = observed_task_values == default_task_value
445+
if not mask.any():
446+
fill_value = float("nan")
447+
else:
448+
idx = mask.nonzero().item()
449+
fill_value = task_range[idx]
450+
# if not all tasks are observed or they are not contiguous integers
451+
# then map them to contiguous integers
452+
if not torch.equal(task_range, all_task_values):
432453
# Create a tensor that maps task values to new task values.
433454
# The number of tasks should be small, so this should be quite efficient.
434455
mapper = torch.full(
435-
(int(task_values.max().item()) + 1,),
436-
float("nan"),
456+
(int(all_task_values.max().item()) + 1,),
457+
fill_value,
437458
dtype=dtype,
438-
device=task_values.device,
459+
device=all_task_values.device,
439460
)
440-
mapper[task_values] = task_range.to(dtype=dtype)
461+
mapper[observed_task_values] = task_range.to(dtype=dtype)
441462
return mapper

0 commit comments

Comments
 (0)