# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, Optional, Tuple, Union
import torch
from botorch.acquisition import AcquisitionFunction
from botorch.optim.homotopy import Homotopy
from botorch.optim.optimize import optimize_acqf
from torch import Tensor
def prune_candidates(
    candidates: Tensor, acq_values: Tensor, prune_tolerance: float
) -> Tensor:
    r"""Prune candidates that lie too close to a better candidate.

    Candidates are visited in order of decreasing acquisition value. The best
    candidate is always kept; every later candidate is kept only if its
    distance (as computed by `torch.cdist`) to each candidate kept so far is
    strictly greater than `prune_tolerance`.

    Args:
        candidates: An `n x d` tensor of candidates.
        acq_values: An `n` tensor of candidate values.
        prune_tolerance: Non-negative distance threshold. A candidate within
            this distance of an already kept (higher-valued) candidate is
            dropped.

    Returns:
        An `m x d` tensor of pruned candidates (`m <= n`), ordered by
        decreasing acquisition value.

    Raises:
        ValueError: If `candidates` is not 2-dimensional, if `acq_values` is
            not a 1-dimensional tensor with one entry per candidate, or if
            `prune_tolerance` is negative.
    """
    if candidates.ndim != 2:
        raise ValueError("`candidates` must be of size `n x d`.")
    if acq_values.ndim != 1 or len(acq_values) != candidates.shape[0]:
        raise ValueError("`acq_values` must be of size `n`.")
    if prune_tolerance < 0:
        raise ValueError("`prune_tolerance` must be >= 0.")

    # Visit the candidates from best to worst so that, among a cluster of
    # near-duplicates, the one with the highest value is the one that survives.
    candidates = candidates[acq_values.argsort(descending=True)]

    kept = candidates[:1, :]
    for i in range(1, candidates.shape[0]):
        row = candidates[i : i + 1, :]
        # Keep `row` only if it is farther than the tolerance from every
        # candidate retained so far.
        if torch.cdist(row, kept).min() > prune_tolerance:
            kept = torch.cat([kept, row], dim=-2)
    return kept
def optimize_acqf_homotopy(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    homotopy: Homotopy,
    num_restarts: int,
    raw_samples: Optional[int] = None,
    fixed_features: Optional[Dict[int, float]] = None,
    options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
    final_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
    batch_initial_conditions: Optional[Tensor] = None,
    post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
    prune_tolerance: float = 1e-4,
) -> Tuple[Tensor, Tensor]:
    r"""Generate a set of candidates via multi-start optimization.

    The `q` candidates are generated one at a time. For each candidate the
    homotopy is restarted and the acquisition function is re-optimized (warm
    started from the previous solutions, with near-duplicates pruned) after
    every `homotopy.step()` until `homotopy.should_stop`. A final optimization
    with `final_options` follows, after which the best candidate is selected.
    When `q > 1`, the candidates selected so far are registered as pending
    points on `acq_function` while the next one is generated; the original
    `X_pending` is restored before returning.

    Args:
        acq_function: An AcquisitionFunction.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
        q: The number of candidates. Must be at least 1.
        homotopy: Homotopy object that will make the necessary modifications to the
            problem when calling `step()`.
        num_restarts: The number of starting points for multistart acquisition
            function optimization.
        raw_samples: The number of samples for initialization. This is required
            if `batch_initial_conditions` is not specified.
        fixed_features: A map `{feature_index: value}` for features that
            should be fixed to a particular value during generation.
        options: Options for candidate generation.
        final_options: Options for candidate generation in the last homotopy step.
        batch_initial_conditions: A tensor to specify the initial conditions. Set
            this if you do not want to use default initialization strategy.
        post_processing_func: Post processing function (such as rounding or
            clamping) that is applied before choosing the final candidate.
        prune_tolerance: Distance threshold passed to `prune_candidates`;
            after each homotopy step, solutions within this distance of a
            better solution are dropped before the next warm start.

    Returns:
        A two-element tuple containing

        - the selected candidates, concatenated along dim `-2` (one row per
          generated candidate, i.e. `q x d` for the `n x 1 x d` solutions
          produced with `q=1` in the inner optimizations), and
        - a tensor stacking the acquisition value of each selected candidate.

    Raises:
        ValueError: If `q` is smaller than 1.
    """
    if q < 1:
        # Without this guard an empty loop would fall through to an unbound
        # `selected_candidates` and fail with an opaque UnboundLocalError.
        raise ValueError("`q` must be a positive integer.")

    candidate_list, acq_value_list = [], []
    if q > 1:
        base_X_pending = acq_function.X_pending

    try:
        for _ in range(q):
            candidates = batch_initial_conditions
            homotopy.restart()

            while not homotopy.should_stop:
                candidates, acq_values = optimize_acqf(
                    q=1,
                    acq_function=acq_function,
                    bounds=bounds,
                    num_restarts=num_restarts,
                    batch_initial_conditions=candidates,
                    raw_samples=raw_samples,
                    fixed_features=fixed_features,
                    return_best_only=False,
                    options=options,
                )
                homotopy.step()

                # Prune candidates so that the next warm start does not spend
                # restarts on (near-)identical solutions.
                candidates = prune_candidates(
                    candidates=candidates.squeeze(1),
                    acq_values=acq_values,
                    prune_tolerance=prune_tolerance,
                ).unsqueeze(1)

            # Optimize one more time with the final options. `fixed_features`
            # must be honored here as well, otherwise the final refinement can
            # move features the caller asked to hold fixed. `raw_samples` is
            # forwarded so initialization is still possible if the homotopy
            # loop above did not run and no initial conditions were given.
            candidates, acq_values = optimize_acqf(
                q=1,
                acq_function=acq_function,
                bounds=bounds,
                num_restarts=num_restarts,
                batch_initial_conditions=candidates,
                raw_samples=raw_samples,
                fixed_features=fixed_features,
                return_best_only=False,
                options=final_options,
            )

            # Post-process the candidates and grab the best candidate. The
            # values are re-evaluated because post-processing may have moved
            # the candidates.
            if post_processing_func is not None:
                candidates = post_processing_func(candidates)
                acq_values = acq_function(candidates)
            best = torch.argmax(acq_values.view(-1), dim=0)
            candidate, acq_value = candidates[best], acq_values[best]

            # Keep the new candidate and update the pending points
            candidate_list.append(candidate)
            acq_value_list.append(acq_value)
            selected_candidates = torch.cat(candidate_list, dim=-2)

            if q > 1:
                acq_function.set_X_pending(
                    torch.cat([base_X_pending, selected_candidates], dim=-2)
                    if base_X_pending is not None
                    else selected_candidates
                )
    finally:
        # Restore the caller's objects even if an optimization step raised, so
        # a failure does not leave stale pending points or homotopy parameters.
        if q > 1:  # Reset acq_function to previous X_pending state
            acq_function.set_X_pending(base_X_pending)
        homotopy.reset()  # Reset the homotopy parameters

    return selected_candidates, torch.stack(acq_value_list)