Meta-learning with RGPE
Meta-Learning with the Rank-Weighted GP Ensemble (RGPE)
BoTorch is designed to be model-agnostic and only requires that a model conform to a minimal interface. This tutorial walks through an example of implementing the rank-weighted Gaussian process ensemble (RGPE) [Feurer, Letham, Bakshy ICML 2018 AutoML Workshop] and using the RGPE in BoTorch to do meta-learning across related optimization tasks.
- Original paper: https://arxiv.org/pdf/1802.02219.pdf
 
# Install dependencies if we are running in colab
import sys
if 'google.colab' in sys.modules:
    %pip install botorch
import os
import torch
import math
# Fix the global torch RNG so the noisy observations below are reproducible.
torch.manual_seed(29)
# Prefer a GPU when one is available; all tensors use double precision.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dtype = torch.double
# Any non-empty SMOKE_TEST environment variable switches to a fast, reduced run.
SMOKE_TEST = os.environ.get("SMOKE_TEST")
Toy Problem
- We consider optimizing the following 1-D synthetic function
$$f(x, s_i) = \frac{x}{10} + x \sin(x + s_i + \pi)$$
where
$$s_i = \frac{\pi i}{12}$$
is a task-dependent shift parameter and $i$ is the task index.
- In this tutorial, we will consider the scenario where we have collected data from 5 prior tasks (referred to as base tasks), each with a different task-dependent shift parameter $s_i$.
- The goal now is to use meta-learning to improve sample efficiency when optimizing a 6th task.
 
Toy Problem Setup
First, let's define a function to compute the shift parameter and set the shift amount for the target task.
# Number of related base tasks with prior data (reduced in smoke-test mode).
NUM_BASE_TASKS = 2 if SMOKE_TEST else 5
def task_shift(task):
    """
    Return the shift applied to the objective for task index `task`.

    The shift is `pi * task / 12`, i.e. it grows linearly with the task index.
    """
    scaled = math.pi * task
    return scaled / 12.0


# The target task is the unshifted objective.
TARGET_SHIFT = 0.0
Then, let's define our function and set bounds on $x \in [-10, 10]$.
# Input domain of the objective: row 0 is the lower bound, row 1 the upper bound.
BOUNDS = torch.tensor([[-10.0], [10.0]], dtype=dtype, device=device)


def f(X, shift=TARGET_SHIFT):
    """
    Torch-compatible 1-D objective, evaluated elementwise on `X`.

    Computes `X * sin(X + pi + shift) + X / 10`. The default `shift` yields the
    target task; base tasks pass their own shift (see `task_shift`).
    """
    phase = X + math.pi + shift
    return X * torch.sin(phase) + X / 10.0
Sample training data for prior base tasks
We sample data from a Sobol sequence to help ensure numerical stability when using a small amount of 1-D data. Sobol sequences help prevent us from sampling a bunch of training points that are close together.
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import normalize, unnormalize
noise_std = 0.05
num_training_points = 20
# Sample data for each base task. Each task gets its own Sobol seed so the
# design points differ between tasks but are reproducible.
data_by_task = {}
for task in range(NUM_BASE_TASKS):
    # quasi-random design in the original input space, squeezed to `n x 1`
    raw_x = draw_sobol_samples(
        bounds=BOUNDS, n=num_training_points, q=1, seed=task + 5397923
    ).squeeze(1)
    # base task `task` uses shift index `task + 1` (the target task uses TARGET_SHIFT)
    f_x = f(raw_x, task_shift(task + 1))
    train_y = f_x + noise_std * torch.randn_like(f_x)
    train_yvar = torch.full_like(train_y, noise_std**2)
    data_by_task[task] = dict(
        train_x=normalize(raw_x, bounds=BOUNDS),  # inputs scaled to [0, 1]
        train_y=train_y,
        train_yvar=train_yvar,
    )
Let's plot the base tasks and the target task function along with the observed points
from matplotlib import pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
# dense grid over the original input domain for drawing the true functions
x = torch.linspace(-10, 10, 51)
for task, task_data in data_by_task.items():
    shift = task_shift(task + 1)
    # observed (noisy) points, mapped back from [0, 1] to the original domain
    observed_x = unnormalize(task_data["train_x"], bounds=BOUNDS)
    points = ax.plot(
        observed_x.cpu().numpy(),
        task_data["train_y"].cpu().numpy(),
        ".",
        markersize=10,
        label=f"Observed task {task}",
    )
    # true base-task curve, drawn in the same color as its observations
    ax.plot(
        x.detach().numpy(),
        f(x, shift).cpu().numpy(),
        label=f"Base task {task}",
        color=points[0].get_color(),
    )
# true target function as a dashed line
target_curve = f(x, TARGET_SHIFT)
ax.plot(x.detach().numpy(), target_curve.detach().numpy(), "--", label="Target task")
ax.legend(loc="lower right", fontsize=10)
plt.tight_layout()
Fit base task models
First, let's define a helper function to fit a SingleTaskGP with a fixed observed noise level.
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
def get_fitted_model(train_X, train_Y, train_Yvar, state_dict=None):
    """
    Build a SingleTaskGP with fixed observation noise `train_Yvar`.

    If `state_dict` is given, its hyperparameters are loaded into the model and no
    fitting is done; otherwise the hyperparameters are fit by maximizing the exact
    marginal log likelihood.
    """
    model = SingleTaskGP(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
    if state_dict is not None:
        # reuse the provided hyperparameters instead of fitting
        model.load_state_dict(state_dict)
        return model
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(train_X)
    fit_gpytorch_mll(mll)
    return model
Now let's fit a SingleTaskGP for each base task
# Fit one independent SingleTaskGP per base task, in task order.
base_model_list = []
for task in range(NUM_BASE_TASKS):
    print(f"Fitting base model {task}")
    task_data = data_by_task[task]
    model = get_fitted_model(
        task_data["train_x"], task_data["train_y"], task_data["train_yvar"]
    )
    base_model_list.append(model)
Fitting base model 0
Fitting base model 1
Fitting base model 2
Fitting base model 3
Fitting base model 4
Implement the RGPE
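The main idea of the RGPE is to estimate the target function as a weighted sum of the target model and the base models. With $t$ models in total, where model $t$ is the target model:
$$\bar f(\mathbf x \mid \mathcal D) = \sum_{i=1}^{t} w_i f^i(\mathbf x \mid \mathcal D_i)$$
Importantly, the ensemble model is also a GP:
$$\bar f(\mathbf x \mid \mathcal D) \sim \mathcal N\left(\sum_{i=1}^{t} w_i \mu_i(\mathbf x),\ \sum_{i=1}^{t} w_i^2 \sigma_i^2(\mathbf x)\right)$$
The weights $w_i$ for model $i$ are based on the ranking loss between a draw from the model's posterior and the targets. Specifically, the ranking loss for model $i$ on the $n_t$ target-task observations $\mathcal D_t$ is:
$$\mathcal L(f^i, \mathcal D_t) = \sum_{j=1}^{n_t} \sum_{k=1}^{n_t} \mathbb 1\left[\left(f^i(\mathbf x^t_j) < f^i(\mathbf x^t_k)\right) \oplus \left(y^t_j < y^t_k\right)\right]$$
where $\oplus$ is exclusive-or.
The loss for the target model is computed using leave-one-out cross-validation (LOOCV) and is given by:
$$\mathcal L(f^t, \mathcal D_t) = \sum_{j=1}^{n_t} \sum_{k=1}^{n_t} \mathbb 1\left[\left(f^t_{-j}(\mathbf x^t_j) < f^t_{-j}(\mathbf x^t_k)\right) \oplus \left(y^t_j < y^t_k\right)\right]$$
where $f^t_{-j}$ is the model fitted to all data from the target task except training example $j$.
The weights are then computed as:
$$w_i = \frac{1}{S} \sum_{s=1}^{S} \mathbb 1\left(i = \operatorname{argmin}_{i'} \ell_{i', s}\right)$$
where $\ell_{i,s}$ is the ranking loss of the $s$-th posterior sample from model $i$, and $S$ is the number of posterior samples.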
def roll_col(X, shift):
    """
    Rotate columns (the last dimension) of `X` to the right by `shift`.

    Negative shifts rotate to the left. Shifts larger than the number of columns
    wrap around (i.e. are taken modulo the number of columns); the previous
    slicing implementation silently returned `X` unchanged in that case.

    Args:
        X: tensor whose last dimension is rotated.
        shift: integer number of positions to rotate by.
    Returns:
        Tensor: a new tensor with the same shape as `X`.
    """
    return torch.roll(X, shifts=shift, dims=-1)
def compute_ranking_loss(f_samps, target_y):
    """
    Compute ranking loss for each sample from the posterior over target points.

    The ranking loss counts the ordered pairs of target points (j, k) whose relative
    order under a posterior sample disagrees with the order of the observed targets.

    Args:
        f_samps: `n_samples x (n) x n`-dim tensor of samples. A 3-dim tensor is
            treated as LOOCV samples of the target model: dim=1 indexes the `n`
            leave-one-out models and dim=2 the `n` training points. A 1-dim tensor
            is treated as `n_samples x 1` (this arises when n == 1 and the caller
            squeezed away the point dimension).
        target_y: `n x 1`-dim (or `n`-dim) tensor of targets
    Returns:
        Tensor: `n_samples`-dim (long) tensor containing the ranking loss across each sample
    """
    # reshape(-1) (rather than squeeze(-1)) keeps a 1-d result even when n == 1
    y = target_y.reshape(-1)
    # y_less[j, k] is True iff y_j < y_k
    y_less = y.unsqueeze(-1) < y.unsqueeze(-2)
    if f_samps.ndim == 1:
        f_samps = f_samps.unsqueeze(-1)
    if f_samps.ndim == 3:
        # Compute ranking loss for target model.
        # The diagonal of f_samps holds the out-of-sample predictions: entry j is
        # LOO model j evaluated at its held-out point x_j. For each LOO model,
        # compare the out-of-sample prediction to each of its predictions.
        f_out = f_samps.diagonal(dim1=1, dim2=2).unsqueeze(-1)
        f_less = f_out < f_samps
    else:
        # f_less[s, j, k] is True iff sample s has f(x_j) < f(x_k)
        f_less = f_samps.unsqueeze(-1) < f_samps.unsqueeze(-2)
    # Diagonal pairs (j, j) are False on both sides and never count, so summing
    # over all pairs equals summing over all ordered pairs j != k.
    return (f_less ^ y_less).sum(dim=(-2, -1))
Define a function to:
- Create a batch-mode LOOCV GP using the hyperparameters from `target_model`
- Draw a joint sample across all points from the target task (in-sample and out-of-sample)
def get_target_model_loocv_sample_preds(
    train_x, train_y, train_yvar, target_model, num_samples
):
    """
    Create a batch-mode LOOCV GP and draw a joint sample across all points from the target task.

    Batch member `i` of the LOOCV model is conditioned on all target-task training
    data except point `i`. Its hyperparameters are copied from `target_model`
    (expanded to the batch size) rather than re-fit.

    Args:
        train_x: `n x d` tensor of training points
        train_y: `n x 1` tensor of training targets
        train_yvar: `n x 1` tensor of observed noise variances for `train_y`
        target_model: fitted target model
        num_samples: number of mc samples to draw
    Return: `num_samples x n x n`-dim tensor of samples, where dim=1 represents the `n` LOO models,
        and dim=2 represents the `n` training points.
    """
    # Imported locally so the function does not depend on a later notebook cell
    # having been executed first.
    from botorch.sampling.normal import SobolQMCNormalSampler

    batch_size = len(train_x)
    # keep[i, j] is False only for j == i, i.e. row i drops training point i.
    # Build the mask on the data's device (not a global device) to avoid mismatches.
    keep = ~torch.eye(batch_size, dtype=torch.bool, device=train_x.device)
    train_x_cv = torch.stack([train_x[k] for k in keep])
    train_y_cv = torch.stack([train_y[k] for k in keep])
    train_yvar_cv = torch.stack([train_yvar[k] for k in keep])
    # expand each hyperparameter to the batch size of the batch-mode LOOCV model
    state_dict_expanded = {
        name: t.expand(batch_size, *[-1 for _ in range(t.ndim)])
        for name, t in target_model.state_dict().items()
    }
    model = get_fitted_model(
        train_x_cv, train_y_cv, train_yvar_cv, state_dict=state_dict_expanded
    )
    with torch.no_grad():
        posterior = model.posterior(train_x)
        # The model is a batch of `n` GPs evaluated at the `n` training points with
        # one output, so the drawn samples are `num_samples x n x n x 1`; squeeze
        # the trailing output dimension.
        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_samples]))
        return sampler(posterior).squeeze(-1)
def compute_rank_weights(
    train_x, train_y, base_models, target_model, num_samples, train_yvar=None
):
    """
    Compute ranking weights for each base model and the target model (using
        LOOCV for the target model). Note: This implementation does not currently
        address weight dilution, since we only have a small number of base models.
    Args:
        train_x: `n x d` tensor of training points (for target task)
        train_y: `n x 1` tensor of training targets (for target task)
        base_models: list of base models
        target_model: target model
        num_samples: number of mc samples
        train_yvar: `n x 1` tensor of observed noise variances for `train_y`, used
            to build the LOOCV target model. If omitted, falls back to a
            module-level `train_yvar` (which earlier versions of this function
            read implicitly).
    Returns:
        Tensor: `(len(base_models) + 1)`-dim tensor with the ranking weight for each
            model; the last entry is the target model's weight.
    Raises:
        ValueError: if `train_yvar` is not given and no module-level `train_yvar` exists.
    """
    # Imported locally so the function does not depend on a later notebook cell
    # having been executed first.
    from botorch.sampling.normal import SobolQMCNormalSampler

    if train_yvar is None:
        # Backward compatibility: make the former implicit dependency on the
        # notebook-level `train_yvar` explicit, and fail loudly if it is missing.
        train_yvar = globals().get("train_yvar")
        if train_yvar is None:
            raise ValueError(
                "compute_rank_weights requires `train_yvar` for the target task."
            )
    ranking_losses = []
    # compute ranking loss for each base model (no gradients needed for sampling)
    with torch.no_grad():
        for model in base_models:
            # compute posterior over training points for target task
            posterior = model.posterior(train_x)
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_samples]))
            base_f_samps = sampler(posterior).squeeze(-1).squeeze(-1)
            # compute and save ranking loss
            ranking_losses.append(compute_ranking_loss(base_f_samps, train_y))
    # compute ranking loss for target model using LOOCV
    target_f_samps = get_target_model_loocv_sample_preds(
        train_x,
        train_y,
        train_yvar,
        target_model,
        num_samples,
    )
    ranking_losses.append(compute_ranking_loss(target_f_samps, train_y))
    ranking_loss_tensor = torch.stack(ranking_losses)
    # compute best model (minimum ranking loss) for each sample; ties go to the
    # lowest index because argmin returns the first minimum
    best_models = torch.argmin(ranking_loss_tensor, dim=0)
    # compute proportion of samples for which each model is best
    rank_weights = (
        best_models.bincount(minlength=len(ranking_losses)).type_as(train_x)
        / num_samples
    )
    return rank_weights
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.models import GP
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import LikelihoodList
from linear_operator.operators import PsdSumLinearOperator
from torch.nn import ModuleList
class RGPE(GP, GPyTorchModel):
    """
    Rank-weighted GP ensemble.

    At input `x`, the ensemble is a multivariate normal. Its mean is the weighted sum
    of the member posterior means, and its covariance is the sum of the member
    posterior covariances scaled by the squared weights. Inheriting from
    `GPyTorchModel` provides the BoTorch model interface.

    Args:
        models: list of member models; each must have a `likelihood` attribute.
        weights: tensor with one weight per model (presumably non-negative, as
            produced by `compute_rank_weights` — confirm for other callers).
    """

    _num_outputs = 1  # metadata for botorch

    def __init__(self, models, weights):
        super().__init__()
        self.models = ModuleList(models)
        if not all(hasattr(member, "likelihood") for member in models):
            raise ValueError(
                "RGPE currently only supports models that have a likelihood (e.g. ExactGPs)"
            )
        self.likelihood = LikelihoodList(*[member.likelihood for member in models])
        self.weights = weights
        self.to(weights)

    def forward(self, x):
        # Only models with a non-zero weight contribute. Covariances are scaled by
        # weight**2, so the squared weight is what gets tested.
        active_idx = (self.weights**2 > 0).nonzero()
        active_w = self.weights[active_idx]
        # re-normalize the surviving weights so they sum to one
        active_w = active_w / active_w.sum()
        means = []
        covars = []
        for idx, w in zip(active_idx, active_w):
            member_posterior = self.models[idx.item()].posterior(x)
            # drop the trailing output dimension of the member mean
            means.append(w * member_posterior.mean.squeeze(-1))
            covars.append(member_posterior.mvn.lazy_covariance_matrix * w**2)
        # ensemble mean: sum of weighted means; covariance: sum of weight**2-scaled
        # member covariances (kept lazy as a PSD sum)
        ensemble_mean = torch.stack(means).sum(dim=0)
        ensemble_covar = PsdSumLinearOperator(*covars)
        return MultivariateNormal(ensemble_mean, ensemble_covar)
Optimize target function using RGPE + qLogNEI
# suppress GPyTorch warnings about adding jitter
import warnings
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.optim.optimize import optimize_acqf
from botorch.sampling.normal import SobolQMCNormalSampler
warnings.filterwarnings("ignore", "^.*jitter.*", category=RuntimeWarning)
best_rgpe_all = []
best_random_all = []
best_vanilla_nei_all = []
N_BATCH = 10 if not SMOKE_TEST else 2
NUM_POSTERIOR_SAMPLES = 256 if not SMOKE_TEST else 16
RANDOM_INITIALIZATION_SIZE = 3
N_TRIALS = 10 if not SMOKE_TEST else 2
MC_SAMPLES = 512 if not SMOKE_TEST else 32
N_RESTART_CANDIDATES = 512 if not SMOKE_TEST else 8
N_RESTARTS = 10 if not SMOKE_TEST else 2
Q_BATCH_SIZE = 1
# Both acquisition functions are optimized over the normalized [0, 1] input space.
# The bounds tensor is loop-invariant, so build it once.
UNIT_BOUNDS = torch.tensor([[0.0], [1.0]], dtype=dtype, device=device)
# Average over multiple trials
for trial in range(N_TRIALS):
    print(f"Trial {trial + 1} of {N_TRIALS}")
    best_rgpe = []
    best_random = []
    best_vanilla_nei = []
    # Initial random observations, shared by RGPE, vanilla NEI and random search
    raw_x = draw_sobol_samples(
        bounds=BOUNDS, n=RANDOM_INITIALIZATION_SIZE, q=1, seed=trial
    ).squeeze(1)
    train_x = normalize(raw_x, bounds=BOUNDS)
    train_y_noiseless = f(raw_x)
    train_y = train_y_noiseless + noise_std * torch.randn_like(train_y_noiseless)
    train_yvar = torch.full_like(train_y, noise_std**2)
    vanilla_nei_train_x = train_x.clone()
    vanilla_nei_train_y = train_y.clone()
    vanilla_nei_train_yvar = train_yvar.clone()
    # keep track of the best observed (noisy) value at each iteration
    best_value = train_y.max().item()
    best_rgpe.append(best_value)
    best_random.append(best_value)
    vanilla_nei_best_value = best_value
    best_vanilla_nei.append(vanilla_nei_best_value)
    # Run N_BATCH rounds of BayesOpt after the initial random batch
    for iteration in range(N_BATCH):
        # ---- RGPE + qLogNEI ----
        target_model = get_fitted_model(train_x, train_y, train_yvar)
        model_list = base_model_list + [target_model]
        # compute_rank_weights relies on the module-level `train_yvar` above for
        # the LOOCV target model's noise.
        rank_weights = compute_rank_weights(
            train_x,
            train_y,
            base_model_list,
            target_model,
            NUM_POSTERIOR_SAMPLES,
        )
        # create model and acquisition function
        rgpe_model = RGPE(model_list, rank_weights)
        sampler_qnei = SobolQMCNormalSampler(sample_shape=torch.Size([MC_SAMPLES]))
        qNEI = qLogNoisyExpectedImprovement(
            model=rgpe_model,
            X_baseline=train_x,
            sampler=sampler_qnei,
            prune_baseline=False,
        )
        # optimize
        candidate, _ = optimize_acqf(
            acq_function=qNEI,
            bounds=UNIT_BOUNDS,
            q=Q_BATCH_SIZE,
            num_restarts=N_RESTARTS,
            raw_samples=N_RESTART_CANDIDATES,
        )
        # fetch the new values
        new_x = candidate.detach()
        new_y_noiseless = f(unnormalize(new_x, bounds=BOUNDS))
        new_y = new_y_noiseless + noise_std * torch.randn_like(new_y_noiseless)
        new_yvar = torch.full_like(new_y, noise_std**2)
        # update training points
        train_x = torch.cat((train_x, new_x))
        train_y = torch.cat((train_y, new_y))
        train_yvar = torch.cat((train_yvar, new_yvar))
        # ---- Random search ----
        random_candidate = torch.rand(1, dtype=dtype, device=device)
        next_random_noiseless = f(unnormalize(random_candidate, bounds=BOUNDS))
        next_random = next_random_noiseless + noise_std * torch.randn_like(
            next_random_noiseless
        )
        next_random_best = next_random.max().item()
        best_random.append(max(best_random[-1], next_random_best))
        # get the new best observed value
        best_value = train_y.max().item()
        best_rgpe.append(best_value)
        # ---- Vanilla NEI for comparison ----
        vanilla_nei_model = get_fitted_model(
            vanilla_nei_train_x,
            vanilla_nei_train_y,
            vanilla_nei_train_yvar,
        )
        vanilla_nei_sampler = SobolQMCNormalSampler(
            sample_shape=torch.Size([MC_SAMPLES])
        )
        vanilla_qNEI = qLogNoisyExpectedImprovement(
            model=vanilla_nei_model,
            X_baseline=vanilla_nei_train_x,
            sampler=vanilla_nei_sampler,
        )
        vanilla_nei_candidate, _ = optimize_acqf(
            acq_function=vanilla_qNEI,
            bounds=UNIT_BOUNDS,
            q=Q_BATCH_SIZE,
            num_restarts=N_RESTARTS,
            raw_samples=N_RESTART_CANDIDATES,
        )
        # fetch the new values
        vanilla_nei_new_x = vanilla_nei_candidate.detach()
        vanilla_nei_new_y_noiseless = f(unnormalize(vanilla_nei_new_x, bounds=BOUNDS))
        # draw noise shaped like this method's own observation (previously this
        # mistakenly referenced the RGPE tensor `new_y_noiseless`)
        vanilla_nei_new_y = vanilla_nei_new_y_noiseless + noise_std * torch.randn_like(
            vanilla_nei_new_y_noiseless
        )
        vanilla_nei_new_yvar = torch.full_like(vanilla_nei_new_y, noise_std**2)
        # update training points
        vanilla_nei_train_x = torch.cat([vanilla_nei_train_x, vanilla_nei_new_x])
        vanilla_nei_train_y = torch.cat([vanilla_nei_train_y, vanilla_nei_new_y])
        vanilla_nei_train_yvar = torch.cat(
            [vanilla_nei_train_yvar, vanilla_nei_new_yvar]
        )
        # get the new best observed value
        vanilla_nei_best_value = vanilla_nei_train_y.max().item()
        best_vanilla_nei.append(vanilla_nei_best_value)
    best_rgpe_all.append(best_rgpe)
    best_random_all.append(best_random)
    best_vanilla_nei_all.append(best_vanilla_nei)
Trial 1 of 10
Trial 2 of 10
Trial 3 of 10
Trial 4 of 10
Trial 5 of 10
Trial 6 of 10
/Users/saitcakmak/botorch/botorch/optim/optimize.py:648: RuntimeWarning: Optimization failed in gen_candidates_scipy with the following warning(s):
[OptimizationWarning('Optimization failed within scipy.optimize.minimize with status 2 and message ABNORMAL_TERMINATION_IN_LNSRCH.')]
Trying again with a new set of initial conditions.
return _optimize_acqf_batch(opt_inputs=opt_inputs)
/Users/saitcakmak/botorch/botorch/optim/optimize.py:648: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.
return _optimize_acqf_batch(opt_inputs=opt_inputs)
/Users/saitcakmak/botorch/botorch/optim/optimize.py:648: RuntimeWarning: Optimization failed in gen_candidates_scipy with the following warning(s):
[BotorchWarning('Low-rank cholesky updates failed due NaNs or due to an ill-conditioned covariance matrix. Falling back to standard sampling.'), OptimizationWarning('Optimization failed within scipy.optimize.minimize with status 2 and message ABNORMAL_TERMINATION_IN_LNSRCH.'), BotorchWarning('Low-rank cholesky updates failed due NaNs or due to an ill-conditioned covariance matrix. Falling back to standard sampling.')]
Trying again with a new set of initial conditions.
return _optimize_acqf_batch(opt_inputs=opt_inputs)
Trial 7 of 10
/Users/saitcakmak/botorch/botorch/optim/optimize.py:648: RuntimeWarning: Optimization failed in gen_candidates_scipy with the following warning(s):
[OptimizationWarning('Optimization failed within scipy.optimize.minimize with status 2 and message ABNORMAL_TERMINATION_IN_LNSRCH.')]
Trying again with a new set of initial conditions.
return _optimize_acqf_batch(opt_inputs=opt_inputs)
/Users/saitcakmak/botorch/botorch/optim/optimize.py:648: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.
return _optimize_acqf_batch(opt_inputs=opt_inputs)
Trial 8 of 10
/Users/saitcakmak/botorch/botorch/optim/optimize.py:648: RuntimeWarning: Optimization failed in gen_candidates_scipy with the following warning(s):
[OptimizationWarning('Optimization failed within scipy.optimize.minimize with status 2 and message ABNORMAL_TERMINATION_IN_LNSRCH.')]
Trying again with a new set of initial conditions.
return _optimize_acqf_batch(opt_inputs=opt_inputs)
/Users/saitcakmak/botorch/botorch/optim/optimize.py:648: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.
return _optimize_acqf_batch(opt_inputs=opt_inputs)
Trial 9 of 10
Trial 10 of 10
Trial 8 of 10
Trial 9 of 10
[W 240829 09:21:46 optimize:564] Optimization failed in gen_candidates_scipy with the following warning(s):
  [OptimizationWarning('Optimization failed within scipy.optimize.minimize with status 2 and message ABNORMAL_TERMINATION_IN_LNSRCH.')]
  Trying again with a new set of initial conditions.
[W 240829 09:21:46 optimize:564] Optimization failed on the second try, after generating a new set of initial conditions.
Trial 10 of 10
Plot best observed value vs iteration
import numpy as np
best_rgpe_all = np.array(best_rgpe_all)
best_random_all = np.array(best_random_all)
best_vanilla_nei_all = np.array(best_vanilla_nei_all)


def _ci95_half_width(values):
    """
    Half-width of a normal-approximation 95% CI for the per-iteration mean.

    `values` is a `n_trials x n_points` array. The standard error uses the sample
    standard deviation (ddof=1), falling back to ddof=0 when there is a single
    trial (where the sample std is undefined).
    """
    n_trials = values.shape[0]
    ddof = 1 if n_trials > 1 else 0
    return 1.96 * values.std(axis=0, ddof=ddof) / math.sqrt(n_trials)


# x-axis: number of target-task observations so far (initial design plus one
# point per BO iteration, since Q_BATCH_SIZE == 1)
x = range(RANDOM_INITIALIZATION_SIZE, RANDOM_INITIALIZATION_SIZE + N_BATCH + 1)
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
# one error-bar curve per method, in a fixed order so colors stay stable
for label, values in (
    ("RGPE - LogNEI", best_rgpe_all),
    ("SingleTaskGP - LogNEI", best_vanilla_nei_all),
    ("Random", best_random_all),
):
    ax.errorbar(
        x,
        values.mean(axis=0),
        yerr=_ci95_half_width(values),
        label=label,
        linewidth=3,
        capsize=5,
        capthick=3,
    )
ax.set_ylim(bottom=0)
ax.set_xlabel("Iteration", fontsize=12)
ax.set_ylabel("Best Observed Value", fontsize=12)
ax.set_title("Best Observed Value by Iteration", fontsize=12)
ax.legend(loc="lower right", fontsize=10)
plt.tight_layout()