Source code for botorch.models.utils.parse_training_data

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Parsing rules for BoTorch datasets."""

from __future__ import annotations

from typing import Any, Dict, Type

import torch
from botorch.models.model import Model
from botorch.models.pairwise_gp import PairwiseGP
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from torch import Tensor

def _encoder(arg: Any) -> Type:
    # Allow type variables to be passed as arguments at runtime
    return arg if isinstance(arg, type) else type(arg)

# Registry used by `parse_training_data`: parsers are registered on it via
# `@dispatcher.register(<consumer type>, <dataset type>)` below. `_encoder`
# maps arguments to types so a class can stand in for an instance
# (presumably applied to each dispatched argument — confirm in Dispatcher).
dispatcher = Dispatcher("parse_training_data", encoder=_encoder)

def parse_training_data(
    consumer: Any,
    training_data: SupervisedDataset,
    **kwargs: Any,
) -> Dict[str, Tensor]:
    r"""Prepares a dataset for consumption by a given object.

    Dispatches on the types of ``consumer`` and ``training_data`` to the
    parser registered for that pair via ``dispatcher.register``.

    Args:
        consumer: The object that will consume the parsed data, or type thereof.
        training_data: A SupervisedDataset.
        **kwargs: Additional keyword arguments forwarded unchanged to the
            registered parser.

    Returns:
        A dictionary containing the extracted information.
    """
    return dispatcher(consumer, training_data, **kwargs)
@dispatcher.register(Model, SupervisedDataset)
def _parse_model_supervised(
    consumer: Model, dataset: SupervisedDataset, **ignore: Any
) -> Dict[str, Tensor]:
    r"""Unpack a supervised dataset into the keyword arguments of a Model.

    Args:
        consumer: The model (or model class) receiving the data; not used here.
        dataset: The supervised dataset to unpack.
        **ignore: Extra keyword arguments; ignored.

    Returns:
        A dict with keys ``train_X`` and ``train_Y`` and, only when
        ``dataset.Yvar`` is not None, ``train_Yvar``.
    """
    parsed = dict(train_X=dataset.X, train_Y=dataset.Y)
    yvar = dataset.Yvar
    if yvar is not None:
        parsed["train_Yvar"] = yvar
    return parsed


@dispatcher.register(PairwiseGP, RankingDataset)
def _parse_pairwiseGP_ranking(
    consumer: PairwiseGP, dataset: RankingDataset, **ignore: Any
) -> Dict[str, Tensor]:
    r"""Unpack a ranking dataset into the keyword arguments of a PairwiseGP.

    Args:
        consumer: The PairwiseGP (or its class) receiving the data; not used here.
        dataset: The ranking dataset to unpack.
        **ignore: Extra keyword arguments; ignored.

    Returns:
        A dict with ``datapoints`` (taken from ``dataset._X.values``) and
        ``comparisons`` (``dataset._X.indices`` reordered along the last
        dimension using ``dataset.Y`` as gather indices).
    """
    # TODO: [T163045056] Not sure what the point of the special container is if we have
    # to further process it here. We should move this logic into RankingDataset.
    container = dataset._X
    ordered_comparisons = torch.gather(
        input=container.indices, dim=-1, index=dataset.Y
    )
    return {
        "datapoints": container.values,
        "comparisons": ordered_comparisons,
    }