Skip to content

Add support for missing tasks in mtgp #2960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(
outcome_transform: OutcomeTransform | None = None,
input_transform: InputTransform | None = None,
pyro_model: MultitaskSaasPyroModel | None = None,
validate_task_values: bool = True,
) -> None:
r"""Initialize the fully Bayesian multi-task GP model.

Expand All @@ -251,6 +252,9 @@ def __init__(
in the model's forward pass.
pyro_model: Optional `PyroModel` that has the same signature as
`MultitaskSaasPyroModel`. Defaults to `MultitaskSaasPyroModel`.
validate_task_values: If True, validate that the task values supplied in the
input are expected task values. If False, unexpected task values
will be mapped to the first output task, if `output_tasks` is supplied.
"""
if not (
train_X.ndim == train_Y.ndim == 2
Expand Down Expand Up @@ -288,22 +292,19 @@ def __init__(
# set on `self` below, it will be applied to the posterior in the
# `posterior` method of `MultiTaskGP`.
outcome_transform=None,
all_tasks=all_tasks,
validate_task_values=validate_task_values,
)
if all_tasks is not None and self._expected_task_values != set(all_tasks):
raise NotImplementedError(
"The `all_tasks` argument is not supported by SAAS MTGP. "
f"The training data includes tasks {self._expected_task_values}, "
f"got {all_tasks=}."
)
self.to(train_X)

self.mean_module = None
self.covar_module = None
self.likelihood = None
if pyro_model is None:
pyro_model = MultitaskSaasPyroModel()
# apply task_mapper
x_before, task_idcs, x_after = self._split_inputs(transformed_X)
pyro_model.set_inputs(
train_X=transformed_X,
train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
train_Y=train_Y,
train_Yvar=train_Yvar,
task_feature=task_feature,
Expand Down
33 changes: 0 additions & 33 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,39 +779,6 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
"long-format" multi-task GP in the style of `MultiTaskGP`.
"""

def _map_tasks(self, task_values: Tensor) -> Tensor:
"""Map raw task values to the task indices used by the model.

Args:
task_values: A tensor of task values.

Returns:
A tensor of task indices with the same shape as the input
tensor.
"""
if self._task_mapper is None:
if not (
torch.all(0 <= task_values) and torch.all(task_values < self.num_tasks)
):
raise ValueError(
"Expected all task features in `X` to be between 0 and "
f"self.num_tasks - 1. Got {task_values}."
)
else:
task_values = task_values.long()

unexpected_task_values = set(task_values.unique().tolist()).difference(
self._expected_task_values
)
if len(unexpected_task_values) > 0:
raise ValueError(
"Received invalid raw task values. Expected raw value to be in"
f" {self._expected_task_values}, but got unexpected task values:"
f" {unexpected_task_values}."
)
task_values = self._task_mapper[task_values]
return task_values

def _apply_noise(
self,
X: Tensor,
Expand Down
58 changes: 52 additions & 6 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
all_tasks: list[int] | None = None,
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
input_transform: InputTransform | None = None,
validate_task_values: bool = True,
) -> None:
r"""Multi-Task GP model using an ICM kernel.

Expand Down Expand Up @@ -157,6 +158,9 @@ def __init__(
instantiation of the model.
input_transform: An input transform that is applied in the model's
forward pass.
validate_task_values: If True, validate that the task values supplied in the
input are expected task values. If False, unexpected task values
will be mapped to the first output task, if `output_tasks` is supplied.

Example:
>>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
Expand Down Expand Up @@ -189,7 +193,7 @@ def __init__(
"This is not allowed as it will lead to errors during model training."
)
all_tasks = all_tasks or all_tasks_inferred
self.num_tasks = len(all_tasks)
self.num_tasks = len(all_tasks_inferred)
if outcome_transform == DEFAULT:
outcome_transform = Standardize(m=1, batch_shape=train_X.shape[:-2])
if outcome_transform is not None:
Expand Down Expand Up @@ -249,19 +253,61 @@ def __init__(

self.covar_module = data_covar_module * task_covar_module
task_mapper = get_task_value_remapping(
task_values=torch.tensor(
all_tasks, dtype=torch.long, device=train_X.device
observed_task_values=torch.tensor(
all_tasks_inferred, dtype=torch.long, device=train_X.device
),
all_task_values=torch.tensor(
sorted(all_tasks), dtype=torch.long, device=train_X.device
),
dtype=train_X.dtype,
default_task_value=None if output_tasks is None else output_tasks[0],
)
self.register_buffer("_task_mapper", task_mapper)
self._expected_task_values = set(all_tasks)
self._expected_task_values = set(all_tasks_inferred)
if input_transform is not None:
self.input_transform = input_transform
if outcome_transform is not None:
self.outcome_transform = outcome_transform
self._validate_task_values = validate_task_values
self.to(train_X)

def _map_tasks(self, task_values: Tensor) -> Tensor:
"""Map raw task values to the task indices used by the model.

Args:
task_values: A tensor of task values.

Returns:
A tensor of task indices with the same shape as the input
tensor.
"""
long_task_values = task_values.long()
if self._validate_task_values:
if self._task_mapper is None:
if not (
torch.all(0 <= task_values)
and torch.all(task_values < self.num_tasks)
):
raise ValueError(
"Expected all task features in `X` to be between 0 and "
f"self.num_tasks - 1. Got {task_values}."
)
else:
unexpected_task_values = set(
long_task_values.unique().tolist()
).difference(self._expected_task_values)
if len(unexpected_task_values) > 0:
raise ValueError(
"Received invalid raw task values. Expected raw value to be in"
f" {self._expected_task_values}, but got unexpected task"
f" values: {unexpected_task_values}."
)
task_values = self._task_mapper[long_task_values]
elif self._task_mapper is not None:
task_values = self._task_mapper[long_task_values]

return task_values

def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
r"""Extracts features before task feature, task indices, and features after
the task feature.
Expand All @@ -274,7 +320,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
3-element tuple containing

- A `q x d` or `b x q x d` tensor with features before the task feature
- A `q` or `b x q` tensor with mapped task indices
- A `q` or `b x q x 1` tensor with mapped task indices
- A `q x d` or `b x q x d` tensor with features after the task feature
"""
batch_shape = x.shape[:-2]
Expand Down Expand Up @@ -314,7 +360,7 @@ def get_all_tasks(
raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
task_feature = task_feature % (d + 1)
all_tasks = (
train_X[..., task_feature].unique(sorted=True).to(dtype=torch.long).tolist()
train_X[..., task_feature].to(dtype=torch.long).unique(sorted=True).tolist()
)
return all_tasks, task_feature, d

Expand Down
20 changes: 15 additions & 5 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,11 +511,13 @@ class StratifiedStandardize(Standardize):

def __init__(
self,
task_values: Tensor,
stratification_idx: int,
observed_task_values: Tensor,
all_task_values: Tensor,
batch_shape: torch.Size = torch.Size(), # noqa: B008
min_stdv: float = 1e-8,
# dtype: torch.dtype = torch.double,
dtype: torch.dtype = torch.double,
default_task_value: int | None = None,
) -> None:
r"""Standardize outcomes (zero mean, unit variance) along stratification dim.

Expand All @@ -528,13 +530,21 @@ def __init__(
batch_shape: The batch_shape of the training targets.
min_stdv: The minimum standard deviation for which to perform
standardization (if lower, only de-mean the data).
default_task_value: The default task value that unexpected tasks are
mapped to. This is used in `get_task_value_remapping`.

"""
OutcomeTransform.__init__(self)
self._stratification_idx = stratification_idx
task_values = task_values.unique(sorted=True)
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
observed_task_values = observed_task_values.unique(sorted=True)
self.strata_mapping = get_task_value_remapping(
observed_task_values=observed_task_values,
all_task_values=all_task_values.unique(sorted=True),
dtype=dtype,
default_task_value=default_task_value,
)
if self.strata_mapping is None:
self.strata_mapping = task_values
self.strata_mapping = observed_task_values
n_strata = self.strata_mapping.shape[0]
self._min_stdv = min_stdv
self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
Expand Down
39 changes: 30 additions & 9 deletions botorch/models/utils/assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,20 @@ class fantasize(_Flag):
_state: bool = False


def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
"""Construct an mapping of discrete task values to contiguous int-valued floats.
def get_task_value_remapping(
observed_task_values: Tensor,
all_task_values: Tensor,
dtype: torch.dtype,
default_task_value: int | None,
) -> Tensor | None:
"""Construct a mapping of observed task values to contiguous int-valued floats.

Args:
task_values: A sorted long-valued tensor of task values.
observed_task_values: A sorted long-valued tensor of the task values
observed in the training data.
all_task_values: A sorted long-valued tensor of all valid task values
(a superset of the observed task values).
dtype: The dtype of the model inputs (e.g. `X`), which the new
task values should have mapped to (e.g. float, double).
default_task_value: The default task value to use for missing task values.

Returns:
A tensor of shape `all_task_values.max() + 1` that maps task values
Expand All @@ -425,17 +432,31 @@ def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor
if dtype not in (torch.float, torch.double):
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
task_range = torch.arange(
len(task_values), dtype=task_values.dtype, device=task_values.device
len(observed_task_values),
dtype=all_task_values.dtype,
device=all_task_values.device,
)
mapper = None
if not torch.equal(task_values, task_range):

if default_task_value is None:
fill_value = float("nan")
else:
mask = observed_task_values == default_task_value
if not mask.any():
fill_value = float("nan")
else:
idx = mask.nonzero().item()
fill_value = task_range[idx]
# if not all tasks are observed or they are not contiguous integers
# then map them to contiguous integers
if not torch.equal(task_range, all_task_values):
# Create a tensor that maps task values to new task values.
# The number of tasks should be small, so this should be quite efficient.
mapper = torch.full(
(int(task_values.max().item()) + 1,),
float("nan"),
(int(all_task_values.max().item()) + 1,),
fill_value,
dtype=dtype,
device=task_values.device,
device=all_task_values.device,
)
mapper[task_values] = task_range.to(dtype=dtype)
mapper[observed_task_values] = task_range.to(dtype=dtype)
return mapper
Loading
Loading