Source code for botorch.acquisition.cost_aware

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Cost functions for cost-aware acquisition functions, e.g. multi-fidelity KG.
To be used in a context where there is an objective/cost tradeoff.
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union

import torch
from botorch import settings
from botorch.acquisition.objective import (
    GenericMCObjective,
    IdentityMCObjective,
    MCAcquisitionObjective,
)
from botorch.exceptions.warnings import CostAwareWarning
from botorch.models.deterministic import DeterministicModel
from botorch.models.gpytorch import GPyTorchModel
from botorch.sampling.base import MCSampler
from torch import Tensor
from torch.nn import Module


class CostAwareUtility(Module, ABC):
    r"""
    Abstract base class for cost-aware utilities.

    :meta private:
    """

    @abstractmethod
    def forward(self, X: Tensor, deltas: Tensor, **kwargs: Any) -> Tensor:
        r"""Evaluate the cost-aware utility on the candidates and improvements.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of with `q` `d`-dim design
                points each for each t-batch.
            deltas: A `num_fantasies x batch_shape`-dim Tensor of `num_fantasy`
                samples from the marginal improvement in utility over the
                current state at `X` for each t-batch.

        Returns:
            A `num_fantasies x batch_shape`-dim Tensor of cost-transformed utilities.
        """
        pass  # pragma: no cover


[docs] class GenericCostAwareUtility(CostAwareUtility): r"""Generic cost-aware utility wrapping a callable.""" def __init__(self, cost: Callable[[Tensor, Tensor], Tensor]) -> None: r"""Generic cost-aware utility wrapping a callable. Args: cost: A callable mapping a `batch_shape x q x d'`-dim candidate set to a `batch_shape`-dim tensor of costs """ super().__init__() self._cost_callable: Callable[[Tensor, Tensor], Tensor] = cost
[docs] def forward(self, X: Tensor, deltas: Tensor, **kwargs: Any) -> Tensor: r"""Evaluate the cost function on the candidates and improvements. Args: X: A `batch_shape x q x d'`-dim Tensor of with `q` `d`-dim design points for each t-batch. deltas: A `num_fantasies x batch_shape`-dim Tensor of `num_fantasy` samples from the marginal improvement in utility over the current state at `X` for each t-batch. Returns: A `num_fantasies x batch_shape`-dim Tensor of cost-weighted utilities. """ return self._cost_callable(X, deltas)
[docs] class InverseCostWeightedUtility(CostAwareUtility): r"""A cost-aware utility using inverse cost weighting based on a model. Computes the cost-aware utility by inverse-weighting samples `U = (u_1, ..., u_N)` of the increase in utility. If `use_mean=True`, this uses the posterior mean `mean_cost` of the cost model, i.e. `weighted utility = mean(U) / mean_cost`. If `use_mean=False`, it uses samples `C = (c_1, ..., c_N)` from the posterior of the cost model and performs the inverse weighting on the sample level: `weighted utility = mean(u_1 / c_1, ..., u_N / c_N)`. Where values in (u_1, ..., u_N) are negative, or for mean(U) < 0, the weighted utility is instead calculated via scaling by the cost, i.e. if `use_mean=True`: `weighted_utility = mean(U) * mean_cost` and if `use_mean=False`: `weighted utility = mean(u_1 * c_1, u_2 / c_2, u_3 * c_3, ..., u_N / c_N)`, depending on whether (`u_*` >= 0), as with `u_2` and `u_N` in this case, or (`u_*` < 0) as with `u_1` and `u_3`. The cost is additive across multiple elements of a q-batch. """ def __init__( self, cost_model: Union[DeterministicModel, GPyTorchModel], use_mean: bool = True, cost_objective: Optional[MCAcquisitionObjective] = None, min_cost: float = 1e-2, ) -> None: r"""Cost-aware utility that weights increase in utility by inverse cost. For negative increases in utility, the utility is instead scaled by the cost. See the class description for more information. Args: cost_model: A model of the cost of evaluating a candidate set `X`, where `X` are the same features as in the model for the acquisition function this is to be used with. If no cost_objective is specified, the outputs are required to be non-negative. use_mean: If True, use the posterior mean, otherwise use posterior samples from the cost model. cost_objective: If specified, transform the posterior mean / the posterior samples from the cost model. This can be used e.g. to un-transform predictions/samples of a cost model fit on the log-transformed cost (often done to ensure non-negativity). If the cost model is multi-output, then by default this will sum the cost across outputs. min_cost: A value used to clamp the cost samples so that they are not too close to zero, which may cause numerical issues. Returns: The inverse-cost-weighted utility. """ super().__init__() if cost_objective is None: if cost_model.num_outputs == 1: cost_objective = IdentityMCObjective() else: # sum over outputs cost_objective = GenericMCObjective(lambda Y, X: Y.sum(dim=-1)) self.cost_model = cost_model self.cost_objective = cost_objective self._use_mean = use_mean self._min_cost = min_cost
[docs] def forward( self, X: Tensor, deltas: Tensor, sampler: Optional[MCSampler] = None, X_evaluation_mask: Optional[Tensor] = None, **kwargs: Any, ) -> Tensor: r"""Evaluate the cost function on the candidates and improvements. Note that negative values of `deltas` are instead scaled by the cost, and not inverse-weighted. See the class description for more information. Args: X: A `batch_shape x q x d`-dim Tensor of with `q` `d`-dim design points each for each t-batch. deltas: A `num_fantasies x batch_shape`-dim Tensor of `num_fantasy` samples from the marginal improvement in utility over the current state at `X` for each t-batch. sampler: A sampler used for sampling from the posterior of the cost model (required if `use_mean=False`, ignored if `use_mean=True`). X_evaluation_mask: A `q x m`-dim boolean tensor indicating which outcomes should be evaluated for each design in the batch. Returns: A `num_fantasies x batch_shape`-dim Tensor of cost-weighted utilities. """ if not self._use_mean and sampler is None: raise RuntimeError("Must provide `sampler` if `use_mean=False`") if X_evaluation_mask is not None: # TODO: support different evaluation masks for each X. This requires # either passing evaluation_mask to `cost_model.posterior` # or assuming that evaluating `cost_model.posterior(X)` on all # `q` points and then only selecting the costs for relevant points # does not change the cost function for each point. This would not be # true for instance if the incremental cost of evaluating an additional # point decreased as the number of points increased. if not all( torch.equal(X_evaluation_mask[0], X_evaluation_mask[i]) for i in range(1, X_evaluation_mask.shape[0]) ): raise NotImplementedError( "Currently, all candidates must be evaluated on the same outputs." ) output_indices = X_evaluation_mask[0].nonzero().view(-1).tolist() else: output_indices = None cost_posterior = self.cost_model.posterior(X, output_indices=output_indices) if self._use_mean: cost = cost_posterior.mean # batch_shape x q x m' else: # This will be of shape num_fantasies x batch_shape x q x m' cost = sampler(cost_posterior) # TODO: Make sure this doesn't change base samples in-place cost = self.cost_objective(cost) # Ensure non-negativity of the cost if settings.debug.on(): if torch.any(cost < -1e-7): warnings.warn( "Encountered negative cost values in InverseCostWeightedUtility", CostAwareWarning, ) # clamp (away from zero) and sum cost across elements of the q-batch - # this will be of shape `num_fantasies x batch_shape` or `batch_shape` cost = cost.clamp_min(self._min_cost).sum(dim=-1) # compute and return the ratio on the sample level - If `use_mean=True` # this operation involves broadcasting the cost across fantasies. # We multiply by the cost if the deltas are <= 0, see discussion #2914 return torch.where(deltas > 0, deltas / cost, deltas * cost)