# Source code for botorch.acquisition.penalized

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Modules to add regularization to acquisition functions.
"""

from __future__ import annotations

import math
from typing import Callable, List, Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.exceptions import UnsupportedError
from torch import Tensor

class L2Penalty(torch.nn.Module):
    r"""L2 penalty class to be added to any arbitrary acquisition function
    to construct a PenalizedAcquisitionFunction."""

    def __init__(self, init_point: Tensor):
        r"""Initializing L2 regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
        """
        super().__init__()
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the squared L2 distance penalty.

        Args:
            X: A "batch_shape x q x dim" representing the points to be evaluated.

        Returns:
            A tensor of size "batch_shape" representing the acqfn for each q-batch.
        """
        # Distance of each q-point from the reference, reduced over the
        # q-batch by taking the worst (largest) offender, then squared.
        distances = torch.norm(X - self.init_point, p=2, dim=-1)
        worst = distances.max(dim=-1).values
        return worst**2

class L1Penalty(torch.nn.Module):
    r"""L1 penalty class to be added to any arbitrary acquisition function
    to construct a PenalizedAcquisitionFunction."""

    def __init__(self, init_point: Tensor):
        r"""Initializing L1 regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
        """
        super().__init__()
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the L1 distance penalty.

        Args:
            X: A "batch_shape x q x dim" representing the points to be evaluated.

        Returns:
            A tensor of size "batch_shape" representing the acqfn for each q-batch.
        """
        # Maximum L1 distance from the reference point over the q-batch.
        deviations = torch.norm(X - self.init_point, p=1, dim=-1)
        return deviations.max(dim=-1).values

class GaussianPenalty(torch.nn.Module):
    r"""Gaussian penalty class to be added to any arbitrary acquisition function
    to construct a PenalizedAcquisitionFunction."""

    def __init__(self, init_point: Tensor, sigma: float):
        r"""Initializing Gaussian regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
            sigma: The parameter used in gaussian function.
        """
        super().__init__()
        self.init_point = init_point
        self.sigma = sigma

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the Gaussian-kernel penalty.

        Args:
            X: A "batch_shape x q x dim" representing the points to be evaluated.

        Returns:
            A tensor of size "batch_shape" representing the acqfn for each q-batch.
        """
        # NOTE(review): the exponent is *positive*, so this penalty grows
        # exponentially with distance from init_point rather than following
        # a Gaussian density; preserved as-is since callers may rely on it.
        sq_dist = torch.norm(X - self.init_point, p=2, dim=-1) ** 2
        kernel = torch.exp(sq_dist / 2 / self.sigma**2)
        return kernel.max(dim=-1).values

class GroupLassoPenalty(torch.nn.Module):
    r"""Group lasso penalty class to be added to any arbitrary acquisition function
    to construct a PenalizedAcquisitionFunction."""

    def __init__(self, init_point: Tensor, groups: List[List[int]]):
        r"""Initializing Group-Lasso regularization.

        Args:
            init_point: The "1 x dim" reference point against which we want
                to regularize.
            groups: Groups of indices used in group lasso.
        """
        super().__init__()
        self.init_point = init_point
        self.groups = groups

    def forward(self, X: Tensor) -> Tensor:
        r"""
        X should be batch_shape x 1 x dim tensor. Evaluation for q-batch is not
        implemented yet.
        """
        # Guard: the group-lasso norm below is only defined for q = 1.
        if X.shape[-2] != 1:
            raise NotImplementedError(
                "group-lasso has not been implemented for q>1 yet."
            )

        # Drop the singleton q-dimension and penalize the offset from the
        # reference point, group by group.
        offset = X.squeeze(-2) - self.init_point
        return group_lasso_regularizer(X=offset, groups=self.groups)

class PenalizedAcquisitionFunction(AcquisitionFunction):
    r"""Single-outcome acquisition function regularized by the given penalty.

    The usage is similar to:
        raw_acqf = NoisyExpectedImprovement(...)
        penalty = GroupLassoPenalty(...)
        acqf = PenalizedAcquisitionFunction(raw_acqf, penalty)
    """

    def __init__(
        self,
        raw_acqf: AcquisitionFunction,
        penalty_func: torch.nn.Module,
        regularization_parameter: float,
    ) -> None:
        r"""Initializing Group-Lasso regularization.

        Args:
            raw_acqf: The raw acquisition function that is going to be regularized.
            penalty_func: The regularization function.
            regularization_parameter: Regularization parameter used in optimization.
        """
        # Share the underlying model with the wrapped acquisition function.
        super().__init__(model=raw_acqf.model)
        self.raw_acqf = raw_acqf
        self.penalty_func = penalty_func
        self.regularization_parameter = regularization_parameter

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalized acquisition value: raw value minus the
        weighted penalty term."""
        base = self.raw_acqf(X=X)
        penalty = self.penalty_func(X)
        return base - self.regularization_parameter * penalty

    @property
    def X_pending(self) -> Optional[Tensor]:
        # Pending points are delegated to the wrapped acquisition function.
        return self.raw_acqf.X_pending

    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
        # Analytic acquisition functions have no notion of pending points.
        if isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
            raise UnsupportedError(
                "The raw acquisition function is Analytic and does not account "
                "for X_pending yet."
            )
        self.raw_acqf.set_X_pending(X_pending=X_pending)

def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
    r"""Computes the group lasso regularization function for the given point.

    The group lasso norm is the sum over groups g of
    sqrt(|g|) * ||X[..., g]||_2.

    Args:
        X: A bxd tensor representing the points to evaluate the regularization at.
        groups: List of indices of different groups.

    Returns:
        Computed group lasso norm of at the given points.
    """
    # BUG FIX: the previous version was missing the `return torch.sum(...)`
    # wrapper around the stacked per-group norms, so the function was
    # syntactically broken / returned nothing. Each group's L2 norm is
    # weighted by sqrt(group size) and the weighted norms are summed.
    return torch.sum(
        torch.stack(
            [math.sqrt(len(g)) * torch.norm(X[..., g], p=2, dim=-1) for g in groups],
            dim=-1,
        ),
        dim=-1,
    )

class L1PenaltyObjective(torch.nn.Module):
    r"""
    L1 penalty objective class. An instance of this class can be added to any
    arbitrary objective to construct a PenalizedMCObjective.
    """

    def __init__(self, init_point: Tensor):
        r"""Initializing L1 penalty objective.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
        """
        super().__init__()
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""
        Args:
            X: A "batch_shape x q x dim" representing the points to be evaluated.

        Returns:
            A "1 x batch_shape x q" tensor representing the penalty for each point.
            The first dimension corresponds to the dimension of MC samples.
        """
        # BUG FIX: the body of forward was missing entirely (only the
        # docstring was present), so the method returned None. Compute the
        # per-point L1 distance from the reference and prepend a singleton
        # MC-sample dimension, matching the documented "1 x batch_shape x q"
        # output shape.
        return torch.norm((X - self.init_point), p=1, dim=-1).unsqueeze(dim=0)

class PenalizedMCObjective(GenericMCObjective):
    r"""Penalized MC objective.

    Allows to construct a penalized MC-objective by adding a penalty term to
    the original objective.

        mc_acq(X) = objective(X) - regularization_parameter * penalty_objective(X)

    Note: PenalizedMCObjective allows adding penalty at the MCObjective level,
    different from the AcquisitionFunction level in PenalizedAcquisitionFunction.

    Example:
        >>> regularization_parameter = 0.01
        >>> init_point = torch.zeros(3) # assume data dim is 3
        >>> objective = lambda Y, X: torch.sqrt(Y).sum(dim=-1)
        >>> l1_penalty_objective = L1PenaltyObjective(init_point=init_point)
        >>> l1_penalized_objective = PenalizedMCObjective(
                objective, l1_penalty_objective, regularization_parameter
            )
        >>> samples = sampler(posterior)
        >>> objective = l1_penalized_objective(samples, X=X)
    """

    def __init__(
        self,
        objective: Callable[[Tensor, Optional[Tensor]], Tensor],
        penalty_objective: torch.nn.Module,
        regularization_parameter: float,
    ) -> None:
        r"""Penalized MC objective.

        Args:
            objective: A callable f(samples, X) mapping a
                sample_shape x batch-shape x q x m-dim Tensor samples and
                an optional batch-shape x q x d-dim Tensor X to a
                sample_shape x batch-shape x q-dim Tensor of objective values.
            penalty_objective: A torch.nn.Module f(X) that takes in a
                batch-shape x q x d-dim Tensor X and outputs a
                1 x batch-shape x q-dim Tensor of penalty objective values.
            regularization_parameter: weight of the penalty (regularization) term
        """
        super().__init__(objective=objective)
        self.penalty_objective = penalty_objective
        self.regularization_parameter = regularization_parameter

    def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        r"""Evaluate the penalized objective on the samples.

        Args:
            samples: A sample_shape x batch_shape x q x m-dim Tensors of
                samples from a model posterior.
            X: A batch_shape x q x d-dim tensor of inputs. Relevant only if
                the objective depends on the inputs explicitly.

        Returns:
            A sample_shape x batch_shape x q-dim Tensor of objective values
            with penalty added for each point.
        """
        # Unpenalized objective values, then subtract the weighted penalty;
        # the penalty's leading singleton dim broadcasts over MC samples.
        base_obj = super().forward(samples=samples, X=X)
        penalty = self.penalty_objective(X)
        return base_obj - self.regularization_parameter * penalty