Source code for botorch.acquisition.penalized

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Modules to add regularization to acquisition functions.
"""

from __future__ import annotations

import math
from typing import List, Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.exceptions import UnsupportedError
from torch import Tensor


[docs]class L2Penalty(torch.nn.Module): r"""L2 penalty class to be added to any arbitrary acquisition function.""" def __init__(self, init_point: Tensor): r"""Initializing L2 regularization. Args: init_point: The "1 x dim" reference point against which we want to regularize. """ super().__init__() self.init_point = init_point
[docs] def forward(self, X: Tensor) -> Tensor: r""" Args: X: A "batch_shape x q x dim" representing the points to be evaluated. Returns: A tensor of size "batch_shape" representing the acqfn for each q-batch. """ regularization_term = ( torch.norm((X - self.init_point), p=2, dim=-1).max(dim=-1).values ** 2 ) return regularization_term
[docs]class GaussianPenalty(torch.nn.Module): r"""Gaussian penalty class to be added to any arbitrary acquisition function.""" def __init__(self, init_point: Tensor, sigma: float): r"""Initializing Gaussian regularization. Args: init_point: The "1 x dim" reference point against which we want to regularize. sigma: The parameter used in gaussian function. """ super().__init__() self.init_point = init_point self.sigma = sigma
[docs] def forward(self, X: Tensor) -> Tensor: r""" Args: X: A "batch_shape x q x dim" representing the points to be evaluated. Returns: A tensor of size "batch_shape" representing the acqfn for each q-batch. """ sq_diff = torch.norm((X - self.init_point), p=2, dim=-1) ** 2 pdf = torch.exp(sq_diff / 2 / self.sigma ** 2) regularization_term = pdf.max(dim=-1).values return regularization_term
[docs]class GroupLassoPenalty(torch.nn.Module): r"""Group lasso penalty class to be added to any arbitrary acquisition function.""" def __init__(self, init_point: Tensor, groups: List[List[int]]): r"""Initializing Group-Lasso regularization. Args: init_point: The "1 x dim" reference point against which we want to regularize. groups: Groups of indices used in group lasso. """ super().__init__() self.init_point = init_point self.groups = groups
[docs] def forward(self, X: Tensor) -> Tensor: r""" X should be batch_shape x 1 x dim tensor. Evaluation for q-batch is not implemented yet. """ if X.shape[-2] != 1: raise NotImplementedError( "group-lasso has not been implemented for q>1 yet." ) regularization_term = group_lasso_regularizer( X=X.squeeze(-2) - self.init_point, groups=self.groups ) return regularization_term
[docs]class PenalizedAcquisitionFunction(AcquisitionFunction): r"""Single-outcome acquisition function regularized by the given penalty. The usage is similar to: raw_acqf = NoisyExpectedImprovement(...) penalty = GroupLassoPenalty(...) acqf = PenalizedAcquisitionFunction(raw_acqf, penalty) """ def __init__( self, raw_acqf: AcquisitionFunction, penalty_func: torch.nn.Module, regularization_parameter: float, ) -> None: r"""Initializing Group-Lasso regularization. Args: raw_acqf: The raw acquisition function that is going to be regularized. penalty_func: The regularization function. regularization_parameter: Regularization parameter used in optimization. """ super().__init__(model=raw_acqf.model) self.raw_acqf = raw_acqf self.penalty_func = penalty_func self.regularization_parameter = regularization_parameter
[docs] def forward(self, X: Tensor) -> Tensor: raw_value = self.raw_acqf(X=X) penalty_term = self.penalty_func(X) return raw_value - self.regularization_parameter * penalty_term
@property def X_pending(self) -> Optional[Tensor]: return self.raw_acqf.X_pending
[docs] def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): self.raw_acqf.set_X_pending(X_pending=X_pending) else: raise UnsupportedError( "The raw acquisition function is Analytic and does not account " "for X_pending yet." )
[docs]def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor: r"""Computes the group lasso regularization function for the given point. Args: X: A bxd tensor representing the points to evaluate the regularization at. groups: List of indices of different groups. Returns: Computed group lasso norm of at the given points. """ return torch.sum( torch.stack( [math.sqrt(len(g)) * torch.norm(X[..., g], p=2, dim=-1) for g in groups], dim=-1, ), dim=-1, )