# Source code for botorch.optim.optimize_homotopy

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, Optional, Tuple, Union

import torch
from botorch.acquisition import AcquisitionFunction
from botorch.optim.homotopy import Homotopy
from botorch.optim.optimize import optimize_acqf
from torch import Tensor


def prune_candidates(
    candidates: Tensor, acq_values: Tensor, prune_tolerance: float
) -> Tensor:
    r"""Greedily drop candidates that lie close to a better candidate.

    Candidates are visited in order of decreasing acquisition value. A
    candidate is kept only if its Euclidean distance (``torch.cdist`` with its
    default ``p=2``) to every candidate kept so far is strictly greater than
    `prune_tolerance`. The highest-valued candidate is always kept when
    `n >= 1`.

    Args:
        candidates: An `n x d` tensor of candidates.
        acq_values: An `n` tensor of candidate values.
        prune_tolerance: The minimum distance to prune candidates. Must be
            non-negative.

    Returns:
        An `m x d` tensor of pruned candidates (`m <= n`), ordered by
        decreasing acquisition value.

    Raises:
        ValueError: If `candidates` is not 2-dimensional, if `acq_values` is
            not 1-dimensional with one entry per candidate, or if
            `prune_tolerance` is negative.
    """
    if candidates.ndim != 2:
        raise ValueError("`candidates` must be of size `n x d`.")
    if acq_values.ndim != 1 or len(acq_values) != candidates.shape[0]:
        raise ValueError("`acq_values` must be of size `n`.")
    if prune_tolerance < 0:
        raise ValueError("`prune_tolerance` must be >= 0.")

    ranked = candidates[acq_values.argsort(descending=True)]
    # The best candidate seeds the kept set; an empty input stays empty.
    kept = ranked[:1, :]
    for point in ranked[1:]:
        row = point.unsqueeze(0)  # `1 x d`, as required by `torch.cdist`
        if torch.cdist(row, kept).min() > prune_tolerance:
            kept = torch.cat([kept, row], dim=-2)
    return kept
def optimize_acqf_homotopy(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    homotopy: Homotopy,
    num_restarts: int,
    raw_samples: Optional[int] = None,
    fixed_features: Optional[Dict[int, float]] = None,
    options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
    final_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
    batch_initial_conditions: Optional[Tensor] = None,
    post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
    prune_tolerance: float = 1e-4,
) -> Tuple[Tensor, Tensor]:
    r"""Generate a set of candidates via multi-start optimization.

    Candidates are selected sequentially, one at a time. For each of the `q`
    candidates, the homotopy is restarted and the acquisition function is
    optimized once per homotopy step. Each step is warm-started from the
    pruned solutions of the previous step. A final optimization with
    `final_options` is then run and the best resulting point is kept. For
    `q > 1`, the points selected so far are appended to the acquisition
    function's `X_pending` before the next selection. The original
    `X_pending` is restored afterwards.

    Args:
        acq_function: An AcquisitionFunction.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of
            `X`.
        q: The number of candidates. Must be at least 1.
        homotopy: Homotopy object that will make the necessary modifications
            to the problem when calling `step()`.
        num_restarts: The number of starting points for multistart
            acquisition function optimization.
        raw_samples: The number of samples for initialization. This is
            required if `batch_initial_conditions` is not specified.
        fixed_features: A map `{feature_index: value}` for features that
            should be fixed to a particular value during generation. It is
            applied in every homotopy step and in the final optimization.
        options: Options for candidate generation.
        final_options: Options for candidate generation in the last homotopy
            step.
        batch_initial_conditions: A tensor to specify the initial conditions.
            Set this if you do not want to use the default initialization
            strategy.
        post_processing_func: Post processing function (such as rounding or
            clamping) that is applied before choosing the final candidate.
        prune_tolerance: After each homotopy step, a restart is dropped if it
            lies within this distance of a better restart (see
            `prune_candidates`). Must be >= 0.

    Returns:
        A two-element tuple containing

        - a `q x d`-dim tensor of the selected candidates.
        - a `q`-dim tensor of their acquisition values.

    Raises:
        ValueError: If `q < 1`.
    """
    if q < 1:
        raise ValueError("`q` must be a positive integer.")

    candidate_list, acq_value_list = [], []
    # Only sequential (q > 1) selection touches `X_pending`.
    base_X_pending = acq_function.X_pending if q > 1 else None

    try:
        for _ in range(q):
            candidates = batch_initial_conditions
            homotopy.restart()

            while not homotopy.should_stop:
                candidates, acq_values = optimize_acqf(
                    q=1,
                    acq_function=acq_function,
                    bounds=bounds,
                    num_restarts=num_restarts,
                    batch_initial_conditions=candidates,
                    raw_samples=raw_samples,
                    fixed_features=fixed_features,
                    return_best_only=False,
                    options=options,
                )
                homotopy.step()

                # Drop restarts that are near-duplicates of better ones, so
                # the next homotopy step only re-optimizes distinct points.
                candidates = prune_candidates(
                    candidates=candidates.squeeze(1),
                    acq_values=acq_values,
                    prune_tolerance=prune_tolerance,
                ).unsqueeze(1)

            # Optimize one more time with the final options. `fixed_features`
            # is forwarded so fixed features are honored here as well.
            # `raw_samples` is forwarded so initialization is still possible
            # if the homotopy took no steps and `candidates` is still `None`.
            candidates, acq_values = optimize_acqf(
                q=1,
                acq_function=acq_function,
                bounds=bounds,
                num_restarts=num_restarts,
                batch_initial_conditions=candidates,
                raw_samples=raw_samples,
                fixed_features=fixed_features,
                return_best_only=False,
                options=final_options,
            )

            # Post-process the candidates and grab the best candidate.
            if post_processing_func is not None:
                candidates = post_processing_func(candidates)
                # These values are only used for selection and for the return
                # value, so no autograd graph is needed.
                with torch.no_grad():
                    acq_values = acq_function(candidates)

            best = torch.argmax(acq_values.view(-1), dim=0)
            candidate, acq_value = candidates[best], acq_values[best]

            # Keep the new candidate and update the pending points.
            candidate_list.append(candidate)
            acq_value_list.append(acq_value)
            selected_candidates = torch.cat(candidate_list, dim=-2)

            if q > 1:
                acq_function.set_X_pending(
                    torch.cat([base_X_pending, selected_candidates], dim=-2)
                    if base_X_pending is not None
                    else selected_candidates
                )
    finally:
        # Restore shared state even if an optimization step raised.
        if q > 1:
            acq_function.set_X_pending(base_X_pending)
        homotopy.reset()  # Reset the homotopy parameters

    return selected_candidates, torch.stack(acq_value_list)