Skip to content

Commit ec2ad88

Browse files
eddiebergmanfacebook-github-bot
authored andcommitted
typing: Use Self from typing_extensions (#2494)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation Why wait till 3.11 :) The dependancy of `typing_extensions` is already present in `torch` and is a core python development library with the same license as python itself. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #2494 Test Plan: I believe this only needs to pass the lint checks. However as it adds a new dependancy (which should be installed already with `torch`), there could be issues with action runners if installation of dependencies is done in a non-standard manner. [However it seems like it should be fine.](https://github.com/pytorch/botorch/blob/d52205a031ecd83b9ef73ba05ca990a284179edb/.github/workflows/test.yml#L45) ## Related PRs * Closes #2487 (issue, not PR) Reviewed By: saitcakmak Differential Revision: D61976244 Pulled By: Balandat fbshipit-source-id: be3dc67147a1bd210d336e56488f3c6c25b67939
1 parent 18eb95a commit ec2ad88

File tree

3 files changed

+10
-17
lines changed

3 files changed

+10
-17
lines changed

botorch/models/approximate_gp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import copy
3333
import warnings
3434

35-
from typing import Optional, TypeVar, Union
35+
from typing import Optional, Union
3636

3737
import torch
3838
from botorch.exceptions.warnings import UserInputWarning
@@ -68,9 +68,9 @@
6868
)
6969
from torch import Tensor
7070
from torch.nn import Module
71+
from typing_extensions import Self
7172

7273

73-
TApproxModel = TypeVar("TApproxModel", bound="ApproximateGPyTorchModel")
7474
TRANSFORM_WARNING = (
7575
"Using an {ttype} transform with `SingleTaskVariationalGP`. If this "
7676
"model is trained in minibatches, a {ttype} transform with learnable "
@@ -132,11 +132,11 @@ def __init__(
132132
def num_outputs(self):
133133
return self._desired_num_outputs
134134

135-
def eval(self: TApproxModel) -> TApproxModel:
135+
def eval(self) -> Self:
136136
r"""Puts the model in `eval` mode."""
137137
return Module.eval(self)
138138

139-
def train(self: TApproxModel, mode: bool = True) -> TApproxModel:
139+
def train(self, mode: bool = True) -> Self:
140140
r"""Put the model in `train` mode.
141141
142142
Args:

botorch/models/model.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from abc import ABC, abstractmethod
1717
from collections import defaultdict
1818
from collections.abc import Mapping
19-
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
19+
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
2020

2121
import numpy as np
2222
import torch
@@ -37,12 +37,11 @@
3737
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
3838
from torch import Tensor
3939
from torch.nn import Module, ModuleDict, ModuleList
40+
from typing_extensions import Self
4041

4142
if TYPE_CHECKING:
4243
from botorch.acquisition.objective import PosteriorTransform # pragma: no cover
4344

44-
TFantasizeMixin = TypeVar("TFantasizeMixin", bound="FantasizeMixin")
45-
4645

4746
class Model(Module, ABC):
4847
r"""Abstract base class for BoTorch models.
@@ -289,11 +288,7 @@ def __init__(self, args):
289288
"""
290289

291290
@abstractmethod
292-
def condition_on_observations(
293-
self: TFantasizeMixin,
294-
X: Tensor,
295-
Y: Tensor,
296-
) -> TFantasizeMixin:
291+
def condition_on_observations(self, X: Tensor, Y: Tensor) -> Self:
297292
"""
298293
Classes that inherit from `FantasizeMixin` must implement
299294
a `condition_on_observations` method.
@@ -322,16 +317,13 @@ def transform_inputs(
322317
a `transform_inputs` method.
323318
"""
324319

325-
# When Python 3.11 arrives we can start annotating return types like
326-
# this as
327-
# 'Self', but at this point the verbose 'T...' syntax is needed.
328320
def fantasize(
329-
self: TFantasizeMixin,
321+
self,
330322
X: Tensor,
331323
sampler: MCSampler,
332324
observation_noise: Optional[Tensor] = None,
333325
**kwargs: Any,
334-
) -> TFantasizeMixin:
326+
) -> Self:
335327
r"""Construct a fantasy model.
336328
337329
Constructs a fantasy model in the following fashion:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ scipy
33
mpmath>=0.19,<=1.3
44
torch>=2.0.1
55
pyro-ppl>=1.8.4
6+
typing_extensions
67
gpytorch==1.13
78
linear_operator==0.5.3

0 commit comments

Comments
 (0)