# Source code for botorch.test_functions.base

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Base class for test functions for optimization benchmarks."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module

class BaseTestProblem(Module, ABC):
    r"""Base class for test functions.

    Subclasses set ``dim`` and ``_bounds`` and implement ``evaluate_true``.
    ``forward`` wraps ``evaluate_true`` with optional observation noise and
    optional negation.
    """

    # Dimension of the function's input domain (set by subclasses).
    dim: int
    # One (float, float) pair per input dimension. ``__init__`` converts this
    # into the 2 x d ``bounds`` buffer (presumably row 0 = lower bounds and
    # row 1 = upper bounds; verify against subclasses).
    _bounds: List[Tuple[float, float]]
    # Class-level flag that is not read in this module. Presumably it is
    # consulted by test code to decide whether gradients are checked at the
    # optimum; verify against callers.
    _check_grad_at_opt: bool = True

    def __init__(
        self, noise_std: Optional[float] = None, negate: bool = False
    ) -> None:
        r"""Base constructor for test functions.

        Args:
            noise_std: Standard deviation of the observation noise. If ``None``,
                ``forward`` never adds noise.
            negate: If True, negate the function.
        """
        super().__init__()
        self.noise_std = noise_std
        self.negate = negate
        # ``_bounds`` is a list of d pairs, giving a (d, 2) tensor; transpose
        # it to (2, d). Registering it as a buffer makes it follow the module
        # through ``.to()`` and appear in ``state_dict``.
        self.register_buffer(
            "bounds", torch.tensor(self._bounds, dtype=torch.float).transpose(-1, -2)
        )

    def forward(self, X: Tensor, noise: bool = True) -> Tensor:
        r"""Evaluate the function on a set of points.

        Args:
            X: The point(s) at which to evaluate the function. A tensor with
                more than one dimension is treated as a batch. Otherwise a
                leading batch dimension is added before evaluation.
            noise: If `True`, add observation noise as specified by `noise_std`.

        Returns:
            The function values, negated if ``negate`` is set. For non-batched
            ``X``, the internally added leading dimension is squeezed away.
            The input ``X`` is never modified.
        """
        batch = X.ndimension() > 1
        X = X if batch else X.unsqueeze(0)
        f = self.evaluate_true(X=X)
        if noise and self.noise_std is not None:
            # Deliberately out of place. ``evaluate_true`` may return a view of
            # ``X``, which an in-place add would silently mutate. It may also
            # return a tensor that autograd saved for backward, which an
            # in-place add would invalidate.
            f = f + self.noise_std * torch.randn_like(f)
        if self.negate:
            f = -f
        return f if batch else f.squeeze(0)

    @abstractmethod
    def evaluate_true(self, X: Tensor) -> Tensor:
        r"""Evaluate the function (w/o observation noise) on a set of points.

        ``forward`` prepends a batch dimension to ``X`` when it has at most one
        dimension before calling this method.
        """
        pass  # pragma: no cover