# Source code for botorch.optim.optimize_homotopy

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Union

import torch
from botorch.acquisition import AcquisitionFunction
from botorch.optim.homotopy import Homotopy
from botorch.optim.optimize import optimize_acqf
from torch import Tensor


def prune_candidates(
    candidates: Tensor, acq_values: Tensor, prune_tolerance: float
) -> Tensor:
    r"""Greedily thin out candidates that sit too close to a better candidate.

    Candidates are visited in order of decreasing acquisition value. A
    candidate is retained only if its Euclidean distance to every previously
    retained candidate is strictly greater than `prune_tolerance`. In any
    cluster of near-duplicates, the highest-valued point is the one that
    survives.

    Args:
        candidates: An `n x d` tensor of candidates.
        acq_values: An `n` tensor of acquisition values, one per row of
            `candidates`.
        prune_tolerance: Non-negative distance threshold. A candidate within
            this distance (inclusive) of an already-retained candidate is
            dropped.

    Returns:
        An `m x d` tensor (`m <= n`) of the retained candidates, ordered by
        decreasing acquisition value.

    Raises:
        ValueError: If `candidates` is not 2-dimensional, if `acq_values` is
            not 1-dimensional with one entry per candidate, or if
            `prune_tolerance` is negative.
    """
    if candidates.ndim != 2:
        raise ValueError("`candidates` must be of size `n x d`.")
    if acq_values.ndim != 1 or len(acq_values) != candidates.shape[0]:
        raise ValueError("`acq_values` must be of size `n`.")
    if prune_tolerance < 0:
        raise ValueError("`prune_tolerance` must be >= 0.")

    # Best-first ordering guarantees the highest-valued point of each cluster
    # is the one that is kept.
    order = torch.argsort(acq_values, descending=True)
    ranked = candidates[order]

    # Seed with the best candidate (or an empty `0 x d` slice when n == 0).
    retained = ranked[:1, :]
    for point in ranked[1:]:
        row = point.unsqueeze(0)  # `1 x d`, the 2-D shape `torch.cdist` expects
        closest = torch.cdist(row, retained).min()
        if closest > prune_tolerance:
            retained = torch.cat([retained, row], dim=-2)
    return retained
def optimize_acqf_homotopy(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    homotopy: Homotopy,
    num_restarts: int,
    raw_samples: Optional[int] = None,
    fixed_features: Optional[dict[int, float]] = None,
    options: Optional[dict[str, Union[bool, float, int, str]]] = None,
    final_options: Optional[dict[str, Union[bool, float, int, str]]] = None,
    batch_initial_conditions: Optional[Tensor] = None,
    post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
    prune_tolerance: float = 1e-4,
) -> tuple[Tensor, Tensor]:
    r"""Generate a set of candidates via multi-start optimization.

    The `q` candidates are generated sequentially. For each candidate:

    - The homotopy is restarted.
    - `acq_function` is optimized (with `q=1`) once per homotopy step. Each
      step is warm-started from the pruned solutions of the previous step.
    - A final optimization with `final_options` polishes the surviving
      starting points, and the best one is kept.

    When `q > 1`, the candidates selected so far are appended to the
    acquisition function's `X_pending` before the next candidate is
    generated. On exit, including on error, `X_pending` is restored to its
    original value and the homotopy is reset.

    Args:
        acq_function: An AcquisitionFunction.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
        q: The number of candidates. Must be at least 1.
        homotopy: Homotopy object that will make the necessary modifications to the
            problem when calling `step()`.
        num_restarts: The number of starting points for multistart acquisition
            function optimization.
        raw_samples: The number of samples for initialization. This is required
            if `batch_initial_conditions` is not specified.
        fixed_features: A map `{feature_index: value}` for features that
            should be fixed to a particular value during generation. Applied to
            every homotopy step and to the final optimization.
        options: Options for candidate generation.
        final_options: Options for candidate generation in the last homotopy step.
        batch_initial_conditions: A tensor to specify the initial conditions. Set
            this if you do not want to use default initialization strategy.
        post_processing_func: Post processing function (such as rounding or
            clamping) that is applied before choosing the final candidate. The
            post-processed candidates are re-scored with `acq_function`.
        prune_tolerance: Distance threshold used to prune candidates between
            homotopy steps. A candidate within this Euclidean distance
            (inclusive) of a better candidate is dropped. Must be >= 0.

    Returns:
        A two-element tuple containing

        - a `q x d` tensor of generated candidates, and
        - a `q`-dim tensor of the associated acquisition values.

    Raises:
        ValueError: If `q < 1`.
    """
    if q < 1:
        raise ValueError(f"`q` must be a positive integer, got {q}.")

    sequential = q > 1
    # Only read `X_pending` when we need to condition on earlier picks; the
    # single-candidate path never touches it.
    base_X_pending = acq_function.X_pending if sequential else None
    candidate_list, acq_value_list = [], []

    try:
        for _ in range(q):
            candidates = batch_initial_conditions
            homotopy.restart()

            while not homotopy.should_stop:
                candidates, acq_values = optimize_acqf(
                    q=1,
                    acq_function=acq_function,
                    bounds=bounds,
                    num_restarts=num_restarts,
                    batch_initial_conditions=candidates,
                    raw_samples=raw_samples,
                    fixed_features=fixed_features,
                    return_best_only=False,
                    options=options,
                )
                homotopy.step()

                # Drop near-duplicate restarts so the next homotopy step does
                # not waste effort on them (`b x 1 x d` -> `m x d` -> `m x 1 x d`).
                candidates = prune_candidates(
                    candidates=candidates.squeeze(1),
                    acq_values=acq_values,
                    prune_tolerance=prune_tolerance,
                ).unsqueeze(1)

            # Optimize one more time with the final options. `fixed_features`
            # must be forwarded here as well. Otherwise this last run is not
            # told about them and may move the fixed dimensions.
            candidates, acq_values = optimize_acqf(
                q=1,
                acq_function=acq_function,
                bounds=bounds,
                num_restarts=num_restarts,
                batch_initial_conditions=candidates,
                fixed_features=fixed_features,
                return_best_only=False,
                options=final_options,
            )

            # Post-process the candidates and re-score them. The scores are
            # only used for selection and as return values, so no autograd
            # graph is needed. Building one could keep it alive through the
            # returned values.
            if post_processing_func is not None:
                candidates = post_processing_func(candidates)
                with torch.no_grad():
                    acq_values = acq_function(candidates)

            best = torch.argmax(acq_values.view(-1), dim=0)
            candidate_list.append(candidates[best])
            acq_value_list.append(acq_values[best])
            selected_candidates = torch.cat(candidate_list, dim=-2)

            # Condition the next pick on everything selected so far.
            if sequential:
                acq_function.set_X_pending(
                    torch.cat([base_X_pending, selected_candidates], dim=-2)
                    if base_X_pending is not None
                    else selected_candidates
                )
    finally:
        # Leave the acquisition function and the homotopy as we found them,
        # even if an optimization step raised.
        if sequential:
            acq_function.set_X_pending(base_X_pending)
        homotopy.reset()

    return selected_candidates, torch.stack(acq_value_list)