Source code for botorch.acquisition.multi_objective.parego

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Union

import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement, TAU_MAX, TAU_RELU
from botorch.acquisition.multi_objective.base import MultiObjectiveMCAcquisitionFunction
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.acquisition.objective import GenericMCObjective
from botorch.models.model import Model
from botorch.posteriors.fully_bayesian import MCMC_DIM
from botorch.sampling.base import MCSampler
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
from botorch.utils.sampling import sample_simplex
from botorch.utils.transforms import is_ensemble
from torch import Tensor


[docs] class qLogNParEGO(qLogNoisyExpectedImprovement, MultiObjectiveMCAcquisitionFunction): def __init__( self, model: Model, X_baseline: Tensor, scalarization_weights: Optional[Tensor] = None, sampler: Optional[MCSampler] = None, objective: Optional[MCMultiOutputObjective] = None, constraints: Optional[list[Callable[[Tensor], Tensor]]] = None, X_pending: Optional[Tensor] = None, eta: Union[Tensor, float] = 1e-3, fat: bool = True, prune_baseline: bool = False, cache_root: bool = True, tau_relu: float = TAU_RELU, tau_max: float = TAU_MAX, ) -> None: r"""q-LogNParEGO supporting m >= 2 outcomes. This acquisition function utilizes qLogNEI to compute the expected improvement over Chebyshev scalarization of the objectives. This is adapted from qNParEGO proposed in [Daulton2020qehvi]_ to utilize log-improvement acquisition functions of [Ament2023logei]_. See [Knowles2005]_ for the original ParEGO algorithm. This implementation assumes maximization of all objectives. If any of the model outputs are to be minimized, either an `objective` should be used to negate the model outputs or the `scalarization_weights` should be provided with negative weights for the outputs to be minimized. Args: model: A fitted multi-output model, producing outputs for `m` objectives and any number of outcome constraints. NOTE: The model posterior must have a `mean` attribute. X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points that have already been observed. These points are considered as the potential best design point. scalarization_weights: A `m`-dim Tensor of weights to be used in the Chebyshev scalarization. If omitted, samples from the unit simplex. sampler: The sampler used to draw base samples. See `MCAcquisitionFunction` more details. objective: The MultiOutputMCAcquisitionObjective under which the samples are evaluated before applying Chebyshev scalarization. Defaults to `IdentityMultiOutputObjective()`. constraints: A list of constraint callables which map a Tensor of posterior samples of dimension `sample_shape x batch-shape x q x m'`-dim to a `sample_shape x batch-shape x q`-dim Tensor. The associated constraints are satisfied if `constraint(samples) < 0`. X_pending: A `batch_shape x q' x d`-dim Tensor of `q'` design points that have points that have been submitted for function evaluation but have not yet been evaluated. Concatenated into `X` upon forward call. Copied and set to have no gradient. eta: Temperature parameter(s) governing the smoothness of the sigmoid approximation to the constraint indicators. See the docs of `compute_(log_)smoothed_constraint_indicator` for details. fat: Toggles the logarithmic / linear asymptotic behavior of the smooth approximation to the ReLU. prune_baseline: If True, remove points in `X_baseline` that are highly unlikely to be the best point. This can significantly improve performance and is generally recommended. In order to customize pruning parameters, instead manually call `botorch.acquisition.utils.prune_inferior_points` on `X_baseline` before instantiating the acquisition function. cache_root: A boolean indicating whether to cache the root decomposition over `X_baseline` and use low-rank updates. tau_max: Temperature parameter controlling the sharpness of the smooth approximations to max. tau_relu: Temperature parameter controlling the sharpness of the smooth approximations to ReLU. """ MultiObjectiveMCAcquisitionFunction.__init__( self, model=model, sampler=sampler, objective=objective, constraints=constraints, eta=eta, ) org_objective = self.objective # Create the composite objective. with torch.no_grad(): Y_baseline = org_objective(model.posterior(X_baseline).mean) if is_ensemble(model): Y_baseline = torch.mean(Y_baseline, dim=MCMC_DIM) scalarization_weights = ( scalarization_weights if scalarization_weights is not None else sample_simplex( d=Y_baseline.shape[-1], device=X_baseline.device, dtype=X_baseline.dtype ).view(-1) ) chebyshev_scalarization = get_chebyshev_scalarization( weights=scalarization_weights, Y=Y_baseline, ) composite_objective = GenericMCObjective( objective=lambda samples, X=None: chebyshev_scalarization( org_objective(samples=samples, X=X), X=X ), ) qLogNoisyExpectedImprovement.__init__( self, model=model, X_baseline=X_baseline, sampler=sampler, # This overwrites self.objective with the composite objective. objective=composite_objective, X_pending=X_pending, constraints=constraints, eta=eta, fat=fat, prune_baseline=prune_baseline, cache_root=cache_root, tau_max=tau_max, tau_relu=tau_relu, ) # Set these after __init__ calls so that they're not overwritten / deleted. # These are intended mainly for easier debugging & transparency. self._org_objective: MCMultiOutputObjective = org_objective self.chebyshev_scalarization: Callable[[Tensor, Optional[Tensor]], Tensor] = ( chebyshev_scalarization ) self.scalarization_weights: Tensor = scalarization_weights self.Y_baseline: Tensor = Y_baseline