# Source code for botorch.models.cost

#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Cost models to be used with multi-fidelity optimization.
"""

from __future__ import annotations

from typing import Dict, Optional

import torch
from torch import Tensor

from .deterministic import DeterministicModel


class AffineFidelityCostModel(DeterministicModel):
    r"""Affine cost model over the fidelity parameters of a candidate set.

    Every element of each q-batch in a candidate set `X` is assigned the cost

        cost = fixed_cost + sum_j weights[j] * X[fidelity_dims[j]]

    where `fidelity_dims` are the sorted keys of `fidelity_weights` (column
    indices into the last dimension of `X`) and `weights` are the matching
    values.
    """

    def __init__(
        self,
        fidelity_weights: Optional[Dict[int, float]] = None,
        fixed_cost: float = 0.01,
    ) -> None:
        r"""Build the affine cost model.

        Args:
            fidelity_weights: Mapping from a subset of the columns of `X` (the
                fidelity parameters) to the weight each column carries in the
                affine cost expression. Defaults to `{-1: 1.0}`, i.e. the last
                column of `X` is the sole fidelity parameter with weight 1.0.
            fixed_cost: Constant cost charged for evaluating a single candidate
                point (one element of a q-batch).
        """
        weight_map = {-1: 1.0} if fidelity_weights is None else fidelity_weights
        super().__init__()
        # Sort the column indices once so that `fidelity_dims` and the
        # `weights` buffer below are guaranteed to share the same ordering.
        # NOTE(review): a negative key and a non-negative key that address the
        # same column (e.g. -1 and d-1) would both be selected in `forward`,
        # so that column would be counted twice.
        ordered_dims = sorted(weight_map)
        self.fidelity_dims = ordered_dims
        self.fixed_cost = fixed_cost
        coefficient_list = [weight_map[dim] for dim in ordered_dims]
        # Registered as a buffer (not a plain attribute) so the weights are
        # part of the module's state.
        self.register_buffer("weights", torch.tensor(coefficient_list))
        # The model emits one scalar cost per candidate point.
        self._num_outputs = 1

    def forward(self, X: Tensor) -> Tensor:
        r"""Compute the affine cost of every point in a candidate set.

        Args:
            X: A `batch_shape x q x d'`-dim tensor of candidate points.

        Returns:
            A `batch_shape x q x 1`-dim tensor whose entries are
            `fixed_cost + sum_j weights[j] * X[..., fidelity_dims[j]]`.
        """
        # TODO: Consider different aggregation (i.e. max) across q-batch
        fidelity_values = X[..., self.fidelity_dims]
        # Match the weights' dtype/device to X before contracting.
        coefficients = self.weights.to(X)
        # Contract the trailing fidelity axis; all leading dims are kept.
        weighted_sum = torch.einsum("...f,f->...", fidelity_values, coefficients)
        return self.fixed_cost + weighted_sum.unsqueeze(-1)