Commit 7b803bd

sdaulton authored and facebook-github-bot committed
add normalization to warping function (#2692)
Summary:
Pull Request resolved: #2692
X-link: facebook/Ax#3259

Warping only works on the unit cube, so this change ensures that inputs are first normalized. It does not use `ChainedInputTransform`, to avoid nesting one `ChainedInputTransform` inside another.

This is needed to support linear+warping models in MBM when we aren't using `UnitX`. We'd want to 1) ensure the data is in the unit cube before applying `Warp`, and 2) then center the warped data at 0 (using `Normalize`). One way to do this would be to apply `Normalize(center=0.5)`, `Warp`, `Normalize(center=0.0)`, but we can't currently specify different options for two different transforms of the same class. So this instead takes an approach suggested by saitcakmak: include normalization in the `Warp` transform itself, since we always want inputs to be in the unit cube before warping.

Reviewed By: saitcakmak, Balandat

Differential Revision: D68356342

fbshipit-source-id: dbd6682d2aacf779c6c813e1a7b779c04f290af0
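As a rough illustration of the intended behavior (a minimal sketch, not part of this commit; it assumes the post-commit `Warp(d, indices, ..., bounds=...)` signature shown in the diff below, and the tensor values are made up):

```python
import torch
from botorch.models.transforms.input import Normalize, Warp

# Inputs live in [1, 2]^3 rather than in the unit cube.
bounds = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
X = 1.0 + torch.rand(8, 3)

# With this change, Warp first maps the indexed inputs to [0, 1] via an
# internal Normalize(d=d, indices=indices, bounds=bounds) and only then
# applies the Kumaraswamy CDF.
warp = Warp(d=3, indices=[0, 2], bounds=bounds)
X_warped = warp(X)

# This mirrors chaining Normalize -> (unit-cube) Warp by hand, without
# nesting a ChainedInputTransform inside another ChainedInputTransform.
unit_bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
normalize = Normalize(d=3, indices=[0, 2], bounds=bounds)
warp_unit = Warp(d=3, indices=[0, 2], bounds=unit_bounds)
print(torch.allclose(X_warped, warp_unit(normalize(X))))  # expected: True
```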
1 parent 6b75672 commit 7b803bd

File tree (4 files changed: +98 −38)

botorch/models/transforms/input.py
test/models/transforms/test_input.py
test/optim/utils/test_acquisition_utils.py
tutorials/bo_with_warped_gp.ipynb

botorch/models/transforms/input.py

Lines changed: 40 additions & 5 deletions
@@ -1071,6 +1071,7 @@ class Warp(ReversibleInputTransform, GPyTorchModule):
 
     def __init__(
         self,
+        d: int,
         indices: list[int],
         transform_on_train: bool = True,
         transform_on_eval: bool = True,
@@ -1080,6 +1081,7 @@ def __init__(
         concentration1_prior: Prior | None = None,
         concentration0_prior: Prior | None = None,
         batch_shape: torch.Size | None = None,
+        bounds: Tensor | None = None,
     ) -> None:
         r"""Initialize transform.
 
@@ -1102,6 +1104,7 @@ def __init__(
                 parameters for each batch of inputs. This should match the input batch
                 shape of the model (i.e., `train_X.shape[:-2]`).
                 NOTE: This is only supported for single-output models.
+            bounds: A `2 x d`-dim tensor of lower and upper bounds for the inputs.
         """
         super().__init__()
         self.register_buffer("indices", torch.tensor(indices, dtype=torch.long))
@@ -1112,6 +1115,9 @@ def __init__(
         self.batch_shape = batch_shape or torch.Size([])
         self._X_min = eps
         self._X_range = 1 - 2 * eps
+        self._normalize = Normalize(
+            d=d, indices=indices, bounds=bounds, batch_shape=self.batch_shape
+        )
         if len(self.batch_shape) > 0:
             # Note: this follows the gpytorch shape convention for lengthscales
             # There is ongoing discussion about the extra `1`.
@@ -1156,7 +1162,7 @@ def _set_concentration(self, i: int, value: float | Tensor) -> None:
         self.initialize(**{f"concentration{i}": value})
 
     @subset_transform
-    def _transform(self, X: Tensor) -> Tensor:
+    def _warp_transform(self, X: Tensor) -> Tensor:
         r"""Warp the inputs through the Kumaraswamy CDF.
 
         Args:
@@ -1165,10 +1171,9 @@ def _transform(self, X: Tensor) -> Tensor:
                 it is broadcastable with self.batch_shape if self.batch_shape is set.
 
         Returns:
-            A `input_batch_shape x (batch_shape) x n x d`-dim tensor of transformed
-            inputs.
+            A `input_batch_shape x (batch_shape) x n x d`-dim tensor
+            of transformed inputs.
         """
-        # normalize to [eps, 1-eps], IDEA: could use Normalize and ChainedTransform.
         return self._k.cdf(
             torch.clamp(
                 X * self._X_range + self._X_min,
@@ -1177,7 +1182,23 @@ def _transform(self, X: Tensor) -> Tensor:
             )
         )
 
-    @subset_transform
+    def _transform(self, X: Tensor) -> Tensor:
+        r"""Warp the inputs through the Kumaraswamy CDF.
+
+        Args:
+            X: A `input_batch_shape x (batch_shape) x n x d`-dim tensor of inputs.
+                batch_shape here can either be self.batch_shape or 1's such that
+                it is broadcastable with self.batch_shape if self.batch_shape is set.
+
+        Returns:
+            A `input_batch_shape x (batch_shape) x n x d`-dim tensor of transformed
+                inputs.
+        """
+        # Normalize to unit cube
+        X = self._normalize(X=X)
+        # normalize to [eps, 1-eps], IDEA: could use Normalize and ChainedTransform.
+        return self._warp_transform(X=X)
+
     def _untransform(self, X: Tensor) -> Tensor:
         r"""Warp the inputs through the Kumaraswamy inverse CDF.
 
@@ -1194,6 +1215,20 @@ def _untransform(self, X: Tensor) -> Tensor:
                 "The right most batch dims of X must match self.batch_shape: "
                 f"({self.batch_shape})."
             )
+        untransformed_X = self._warp_untransform(X=X)
+        return self._normalize.untransform(X=untransformed_X)
+
+    @subset_transform
+    def _warp_untransform(self, X: Tensor) -> Tensor:
+        r"""Warp the inputs through the Kumaraswamy inverse CDF.
+
+        Args:
+            X: A `input_batch_shape x batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `input_batch_shape x batch_shape x n x d`-dim tensor of transformed
+                inputs.
+        """
         # unnormalize from [eps, 1-eps] to [0,1]
         return ((self._k.icdf(X) - self._X_min) / self._X_range).clamp(0.0, 1.0)

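A quick round-trip sketch of the new `_transform`/`_untransform` pairing (illustrative only, not part of the diff; it uses the post-commit constructor above and made-up data):

```python
import torch
from botorch.models.transforms.input import Warp

bounds = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
warp = Warp(d=3, indices=[0, 2], bounds=bounds)

X = 1.0 + torch.rand(5, 3)
X_tf = warp.transform(X)         # normalize to [0, 1], then Kumaraswamy CDF
X_back = warp.untransform(X_tf)  # inverse CDF, then unnormalize back to [1, 2]
print(torch.allclose(X, X_back, atol=1e-3))  # expected: True (up to eps clamping)
```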
test/models/transforms/test_input.py

Lines changed: 34 additions & 9 deletions
@@ -41,8 +41,11 @@
 from torch.nn.functional import one_hot
 
 
-def get_test_warp(indices, **kwargs):
-    warp_tf = Warp(indices=indices, **kwargs)
+def get_test_warp(d, indices, bounds=None, **kwargs):
+    if bounds is None:
+        bounds = torch.zeros(2, d)
+        bounds[1] = 1
+    warp_tf = Warp(d=d, indices=indices, bounds=bounds, **kwargs)
     c0 = torch.tensor([1.0, 2.0])[: len(indices)]
     c1 = torch.tensor([2.0, 3.0])[: len(indices)]
     batch_shape = kwargs.get("batch_shape", torch.Size([]))
@@ -1031,12 +1034,17 @@ def test_warp_transform(self) -> None:
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
             eps = 1e-6 if dtype == torch.double else 1e-5
+            if dtype == torch.float32:
+                # defaults are 1e-5, 1e-8
+                tols = {"rtol": 2e-5, "atol": 8e-8}
+            else:
+                tols = {}
 
             # basic init
             indices = [0, 2]
-            warp_tf = get_test_warp(indices, batch_shape=warp_batch_shape, eps=eps).to(
-                **tkwargs
-            )
+            warp_tf = get_test_warp(
+                d=3, indices=indices, batch_shape=warp_batch_shape, eps=eps
+            ).to(**tkwargs)
             self.assertTrue(warp_tf.training)
 
             k = Kumaraswamy(warp_tf.concentration1, warp_tf.concentration0)
@@ -1049,7 +1057,7 @@ def test_warp_transform(self) -> None:
             X = X.unsqueeze(-3) if len(warp_batch_shape) > 0 else X
             with torch.no_grad():
                 warp_tf = get_test_warp(
-                    indices=indices, batch_shape=warp_batch_shape, eps=eps
+                    d=3, indices=indices, batch_shape=warp_batch_shape, eps=eps
                 ).to(**tkwargs)
                 X_tf = warp_tf(X)
             expected_X_tf = expand_and_copy_tensor(
@@ -1077,7 +1085,8 @@ def test_warp_transform(self) -> None:
 
             # test no transform on eval
             warp_tf = get_test_warp(
-                indices,
+                d=3,
+                indices=indices,
                 transform_on_eval=False,
                 batch_shape=warp_batch_shape,
                 eps=eps,
@@ -1090,6 +1099,7 @@ def test_warp_transform(self) -> None:
 
             # test no transform on train
             warp_tf = get_test_warp(
+                d=3,
                 indices=indices,
                 transform_on_train=False,
                 batch_shape=warp_batch_shape,
@@ -1103,6 +1113,7 @@ def test_warp_transform(self) -> None:
 
             # test equals
             warp_tf2 = get_test_warp(
+                d=3,
                 indices=indices,
                 transform_on_train=False,
                 batch_shape=warp_batch_shape,
@@ -1111,11 +1122,12 @@ def test_warp_transform(self) -> None:
             self.assertTrue(warp_tf.equals(warp_tf2))
             # test different transform_on_train
             warp_tf2 = get_test_warp(
-                indices=indices, batch_shape=warp_batch_shape, eps=eps
+                d=3, indices=indices, batch_shape=warp_batch_shape, eps=eps
             )
             self.assertFalse(warp_tf.equals(warp_tf2))
             # test different indices
             warp_tf2 = get_test_warp(
+                d=3,
                 indices=[0, 1],
                 transform_on_train=False,
                 batch_shape=warp_batch_shape,
@@ -1137,6 +1149,7 @@ def test_warp_transform(self) -> None:
             prior0 = LogNormalPrior(0.0, 0.75).to(**tkwargs)
             prior1 = LogNormalPrior(0.0, 0.5).to(**tkwargs)
             warp_tf = get_test_warp(
+                d=3,
                 indices=[0, 1],
                 concentration0_prior=prior0,
                 concentration1_prior=prior1,
@@ -1148,11 +1161,23 @@ def test_warp_transform(self) -> None:
                 self.assertIsInstance(p, LogNormalPrior)
                 self.assertEqual(p.base_dist.scale, 0.75 if i == 0 else 0.5)
 
+            # test non-unit cube bounds
+            warp_tf = get_test_warp(
+                d=3,
+                indices=[0, 2],
+                eps=eps,
+                batch_shape=warp_batch_shape,
+                bounds=torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], **tkwargs),
+            ).to(**tkwargs)
+            X[..., indices] += 1
+            X_tf = warp_tf(X)
+            self.assertAllClose(expected_X_tf, X_tf, **tols)
+
             # test gradients
             X = 1 + 5 * torch.rand(*batch_shape, 4, 3, **tkwargs)
             X = X.unsqueeze(-3) if len(warp_batch_shape) > 0 else X
             warp_tf = get_test_warp(
-                indices=indices, batch_shape=warp_batch_shape, eps=eps
+                d=3, indices=indices, batch_shape=warp_batch_shape, eps=eps
             ).to(**tkwargs)
             X_tf = warp_tf(X)
             X_tf.sum().backward()

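In isolation, the new non-unit-cube test case boils down to the following check (a sketch using the same values as the test; the `get_test_warp` helper, batch shapes, and float32 tolerances are omitted):

```python
import torch
from botorch.models.transforms.input import Warp

indices = [0, 2]
unit_bounds = torch.zeros(2, 3)
unit_bounds[1] = 1
X = torch.rand(4, 3)

# Reference: warping data that already lies in the unit cube.
expected = Warp(d=3, indices=indices, bounds=unit_bounds)(X)

# Shift the warped dimensions into [1, 2] and pass matching bounds; the
# internal normalization should recover the same warped values.
X_shifted = X.clone()
X_shifted[..., indices] += 1
bounds = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
actual = Warp(d=3, indices=indices, bounds=bounds)(X_shifted)
print(torch.allclose(expected, actual, rtol=2e-5, atol=8e-8))  # expected: True
```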
test/optim/utils/test_acquisition_utils.py

Lines changed: 3 additions & 1 deletion
@@ -240,7 +240,9 @@ def test_get_X_baseline(self):
             # to the train_inputs when the model is in eval mode, we
             # extract the untransformed train_inputs
             model = SingleTaskGP(
-                X_train, Y_train[:, :1], input_transform=Warp(indices=[0, 1])
+                X_train,
+                Y_train[:, :1],
+                input_transform=Warp(d=X_train.shape[-1], indices=[0, 1]),
             )
             model.eval()
             self.assertFalse(torch.equal(model.train_inputs[0], X_train))

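For reference, the same construction pattern outside the test harness (a sketch with made-up training data; only the `Warp(d=..., indices=...)` call reflects the change in this commit):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Warp

X_train = torch.rand(10, 2, dtype=torch.double)
Y_train = torch.rand(10, 1, dtype=torch.double)

# `d` is now required; without explicit `bounds`, the internal Normalize
# uses its defaults for the warped dimensions (see the input.py diff above).
model = SingleTaskGP(
    X_train,
    Y_train,
    input_transform=Warp(d=X_train.shape[-1], indices=[0, 1]),
)
model.eval()
```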
tutorials/bo_with_warped_gp.ipynb

Lines changed: 21 additions & 23 deletions
Large diffs are not rendered by default.
