Skip to content

Commit 5ef3c77

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Add support for missing tasks in mtgp (#2960)
Summary: X-link: facebook/Ax#4121 Currently, cross-validation in Ax fails when using an MTGP if there are multiple metrics and only some metrics have been observed for some tasks. This is a modeling problem, since the model is a ModelListGP and not all MTGPs in the list are required to have the same tasks. Hence when you pass in a test input, the model errors out if there are no observations from that task in the training data. This avoids the error by (optionally) mapping unexpected tasks to the "target task". This does not change the default behavior. For cross-validation in Ax, predictions are discarded if there are no observations for a given (task, metric) pair. This will still error out in Ax if data for the target trial is missing. Differential Revision: D79812024
1 parent 20f1116 commit 5ef3c77

File tree

8 files changed

+311
-96
lines changed

8 files changed

+311
-96
lines changed

botorch/models/fully_bayesian_multitask.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def __init__(
227227
outcome_transform: OutcomeTransform | None = None,
228228
input_transform: InputTransform | None = None,
229229
pyro_model: MultitaskSaasPyroModel | None = None,
230+
validate_task_values: bool = True,
230231
) -> None:
231232
r"""Initialize the fully Bayesian multi-task GP model.
232233
@@ -251,6 +252,9 @@ def __init__(
251252
in the model's forward pass.
252253
pyro_model: Optional `PyroModel` that has the same signature as
253254
`MultitaskSaasPyroModel`. Defaults to `MultitaskSaasPyroModel`.
255+
validate_task_values: If True, validate that the task values supplied in the
256+
input are expected task values. If False, unexpected task values
257+
will be mapped to the first output_task if supplied.
254258
"""
255259
if not (
256260
train_X.ndim == train_Y.ndim == 2
@@ -288,22 +292,19 @@ def __init__(
288292
# set on `self` below, it will be applied to the posterior in the
289293
# `posterior` method of `MultiTaskGP`.
290294
outcome_transform=None,
295+
all_tasks=all_tasks,
296+
validate_task_values=validate_task_values,
291297
)
292-
if all_tasks is not None and self._expected_task_values != set(all_tasks):
293-
raise NotImplementedError(
294-
"The `all_tasks` argument is not supported by SAAS MTGP. "
295-
f"The training data includes tasks {self._expected_task_values}, "
296-
f"got {all_tasks=}."
297-
)
298298
self.to(train_X)
299-
300299
self.mean_module = None
301300
self.covar_module = None
302301
self.likelihood = None
303302
if pyro_model is None:
304303
pyro_model = MultitaskSaasPyroModel()
304+
# apply task_mapper
305+
x_before, task_idcs, x_after = self._split_inputs(transformed_X)
305306
pyro_model.set_inputs(
306-
train_X=transformed_X,
307+
train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
307308
train_Y=train_Y,
308309
train_Yvar=train_Yvar,
309310
task_feature=task_feature,

botorch/models/gpytorch.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -789,27 +789,31 @@ def _map_tasks(self, task_values: Tensor) -> Tensor:
789789
A tensor of task indices with the same shape as the input
790790
tensor.
791791
"""
792-
if self._task_mapper is None:
793-
if not (
794-
torch.all(0 <= task_values) and torch.all(task_values < self.num_tasks)
795-
):
796-
raise ValueError(
797-
"Expected all task features in `X` to be between 0 and "
798-
f"self.num_tasks - 1. Got {task_values}."
799-
)
800-
else:
801-
task_values = task_values.long()
792+
long_task_values = task_values.long()
793+
if self._validate_task_values:
794+
if self._task_mapper is None:
795+
if not (
796+
torch.all(0 <= task_values)
797+
and torch.all(task_values < self.num_tasks)
798+
):
799+
raise ValueError(
800+
"Expected all task features in `X` to be between 0 and "
801+
f"self.num_tasks - 1. Got {task_values}."
802+
)
803+
else:
804+
unexpected_task_values = set(
805+
long_task_values.unique().tolist()
806+
).difference(self._expected_task_values)
807+
if len(unexpected_task_values) > 0:
808+
raise ValueError(
809+
"Received invalid raw task values. Expected raw value to be in"
810+
f" {self._expected_task_values}, but got unexpected task"
811+
f" values: {unexpected_task_values}."
812+
)
813+
task_values = self._task_mapper[long_task_values]
814+
elif self._task_mapper is not None:
815+
task_values = self._task_mapper[long_task_values]
802816

803-
unexpected_task_values = set(task_values.unique().tolist()).difference(
804-
self._expected_task_values
805-
)
806-
if len(unexpected_task_values) > 0:
807-
raise ValueError(
808-
"Received invalid raw task values. Expected raw value to be in"
809-
f" {self._expected_task_values}, but got unexpected task values:"
810-
f" {unexpected_task_values}."
811-
)
812-
task_values = self._task_mapper[task_values]
813817
return task_values
814818

815819
def _apply_noise(

botorch/models/multitask.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def __init__(
115115
all_tasks: list[int] | None = None,
116116
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
117117
input_transform: InputTransform | None = None,
118+
validate_task_values: bool = True,
119+
num_tasks_to_model: int | None = None,
118120
) -> None:
119121
r"""Multi-Task GP model using an ICM kernel.
120122
@@ -157,6 +159,11 @@ def __init__(
157159
instantiation of the model.
158160
input_transform: An input transform that is applied in the model's
159161
forward pass.
162+
validate_task_values: If True, validate that the task values supplied in the
163+
input are expected task values. If False, unexpected task values
164+
will be mapped to the first output_task if supplied.
165+
num_tasks_to_model: The number of tasks to model. If omitted, model only
166+
the tasks inferred from the training data.
160167
161168
Example:
162169
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -189,7 +196,11 @@ def __init__(
189196
"This is not allowed as it will lead to errors during model training."
190197
)
191198
all_tasks = all_tasks or all_tasks_inferred
192-
self.num_tasks = len(all_tasks)
199+
self.num_tasks = (
200+
len(all_tasks_inferred)
201+
if num_tasks_to_model is None
202+
else num_tasks_to_model
203+
)
193204
if outcome_transform == DEFAULT:
194205
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
195206
if outcome_transform is not None:
@@ -249,17 +260,22 @@ def __init__(
249260

250261
self.covar_module = data_covar_module * task_covar_module
251262
task_mapper = get_task_value_remapping(
252-
task_values=torch.tensor(
253-
all_tasks, dtype=torch.long, device=train_X.device
263+
observed_task_values=torch.tensor(
264+
all_tasks_inferred, dtype=torch.long, device=train_X.device
265+
),
266+
all_task_values=torch.tensor(
267+
sorted(all_tasks), dtype=torch.long, device=train_X.device
254268
),
255269
dtype=train_X.dtype,
270+
default_task_value=None if output_tasks is None else output_tasks[0],
256271
)
257272
self.register_buffer("_task_mapper", task_mapper)
258-
self._expected_task_values = set(all_tasks)
273+
self._expected_task_values = set(all_tasks_inferred)
259274
if input_transform is not None:
260275
self.input_transform = input_transform
261276
if outcome_transform is not None:
262277
self.outcome_transform = outcome_transform
278+
self._validate_task_values = validate_task_values
263279
self.to(train_X)
264280

265281
def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
@@ -274,7 +290,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
274290
3-element tuple containing
275291
276292
- A `q x d` or `b x q x d` tensor with features before the task feature
277-
- A `q` or `b x q` tensor with mapped task indices
293+
- A `q` or `b x q x 1` tensor with mapped task indices
278294
- A `q x d` or `b x q x d` tensor with features after the task feature
279295
"""
280296
batch_shape = x.shape[:-2]
@@ -314,7 +330,7 @@ def get_all_tasks(
314330
raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
315331
task_feature = task_feature % (d + 1)
316332
all_tasks = (
317-
train_X[..., task_feature].unique(sorted=True).to(dtype=torch.long).tolist()
333+
train_X[..., task_feature].to(dtype=torch.long).unique(sorted=True).tolist()
318334
)
319335
return all_tasks, task_feature, d
320336

botorch/models/transforms/outcome.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,11 +511,13 @@ class StratifiedStandardize(Standardize):
511511

512512
def __init__(
513513
self,
514-
task_values: Tensor,
515514
stratification_idx: int,
515+
observed_task_values: Tensor,
516+
all_task_values: Tensor,
516517
batch_shape: torch.Size = torch.Size(), # noqa: B008
517518
min_stdv: float = 1e-8,
518-
# dtype: torch.dtype = torch.double,
519+
dtype: torch.dtype = torch.double,
520+
default_task_value: int | None = None,
519521
) -> None:
520522
r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
521523
@@ -528,13 +530,21 @@ def __init__(
528530
batch_shape: The batch_shape of the training targets.
529531
min_stdv: The minimum standard deviation for which to perform
530532
standardization (if lower, only de-mean the data).
533+
default_task_value: The default task value that unexpected tasks are
534+
mapped to. This is used in `get_task_value_remapping`.
535+
531536
"""
532537
OutcomeTransform.__init__(self)
533538
self._stratification_idx = stratification_idx
534-
task_values = task_values.unique(sorted=True)
535-
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
539+
observed_task_values = observed_task_values.unique(sorted=True)
540+
self.strata_mapping = get_task_value_remapping(
541+
observed_task_values=observed_task_values,
542+
all_task_values=all_task_values.unique(sorted=True),
543+
dtype=dtype,
544+
default_task_value=default_task_value,
545+
)
536546
if self.strata_mapping is None:
537-
self.strata_mapping = task_values
547+
self.strata_mapping = observed_task_values
538548
n_strata = self.strata_mapping.shape[0]
539549
self._min_stdv = min_stdv
540550
self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))

botorch/models/utils/assorted.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,20 @@ class fantasize(_Flag):
405405
_state: bool = False
406406

407407

408-
def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
409-
"""Construct an mapping of discrete task values to contiguous int-valued floats.
408+
def get_task_value_remapping(
409+
observed_task_values: Tensor,
410+
all_task_values: Tensor,
411+
dtype: torch.dtype,
412+
default_task_value: int | None,
413+
) -> Tensor | None:
414+
"""Construct a mapping of observed task values to contiguous int-valued floats.
410415
411416
Args:
412-
task_values: A sorted long-valued tensor of task values.
417+
observed_task_values: A sorted long-valued tensor of the observed task values.
418+
all_task_values: A sorted long-valued tensor of all task values (observed or not).
413419
dtype: The dtype of the model inputs (e.g. `X`), which the new
414420
task values should have mapped to (e.g. float, double).
421+
default_task_value: The default task value to use for missing task values.
415422
416423
Returns:
417424
A tensor of shape `all_task_values.max() + 1` that maps task values
@@ -425,17 +432,31 @@ def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor
425432
if dtype not in (torch.float, torch.double):
426433
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
427434
task_range = torch.arange(
428-
len(task_values), dtype=task_values.dtype, device=task_values.device
435+
len(observed_task_values),
436+
dtype=all_task_values.dtype,
437+
device=all_task_values.device,
429438
)
430439
mapper = None
431-
if not torch.equal(task_values, task_range):
440+
441+
if default_task_value is None:
442+
fill_value = float("nan")
443+
else:
444+
mask = observed_task_values == default_task_value
445+
if not mask.any():
446+
fill_value = float("nan")
447+
else:
448+
idx = mask.nonzero().item()
449+
fill_value = task_range[idx]
450+
# if not all tasks are observed or they are not contiguous integers
451+
# then map them to contiguous integers
452+
if not torch.equal(task_range, all_task_values):
432453
# Create a tensor that maps task values to new task values.
433454
# The number of tasks should be small, so this should be quite efficient.
434455
mapper = torch.full(
435-
(int(task_values.max().item()) + 1,),
436-
float("nan"),
456+
(int(all_task_values.max().item()) + 1,),
457+
fill_value,
437458
dtype=dtype,
438-
device=task_values.device,
459+
device=all_task_values.device,
439460
)
440-
mapper[task_values] = task_range.to(dtype=dtype)
461+
mapper[observed_task_values] = task_range.to(dtype=dtype)
441462
return mapper

0 commit comments

Comments
 (0)