Skip to content

Commit 7c5c7cb

Browse files
sdaulton authored and facebook-github-bot committed
add center argument to Normalize (#2680)
Summary: Pull Request resolved: #2680 This allows one to specify a center other than 0.5 for the normalized data. E.g. for GPs with linear kernels, the inputs should be centered at 0. See discussion on cornellius-gp/gpytorch#2617 (comment). Reviewed By: Balandat Differential Revision: D68293784 fbshipit-source-id: d94005fbc2a2c59c936619a5f458d1d7618b50af
1 parent a9c5eff commit 7c5c7cb

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

botorch/models/transforms/input.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def _update_coefficients(self, X: Tensor) -> None:
614614

615615

616616
class Normalize(AffineInputTransform):
617-
r"""Normalize the inputs to the unit cube.
617+
r"""Normalize the inputs to have unit range and be centered at 0.5 (by default).
618618
619619
If no explicit bounds are provided this module is stateful: If in train mode,
620620
calling `forward` updates the module state (i.e. the normalizing bounds). If
@@ -635,6 +635,7 @@ def __init__(
635635
min_range: float = 1e-8,
636636
learn_bounds: bool | None = None,
637637
almost_zero: float = 1e-12,
638+
center: float = 0.5,
638639
) -> None:
639640
r"""Normalize the inputs to the unit cube.
640641
@@ -662,6 +663,7 @@ def __init__(
662663
NOTE: This only applies if `learn_bounds=True`.
663664
learn_bounds: Whether to learn the bounds in train mode. Defaults
664665
to False if bounds are provided, otherwise defaults to True.
666+
center: The center of the range for each parameter. Default: 0.5.
665667
666668
Example:
667669
>>> t = Normalize(d=2)
@@ -704,10 +706,11 @@ def __init__(
704706
"will not be updated and the transform will be a no-op.",
705707
UserInputWarning,
706708
)
709+
self.center = center
707710
super().__init__(
708711
d=d,
709712
coefficient=coefficient,
710-
offset=offset,
713+
offset=offset + (0.5 - center) * coefficient,
711714
indices=indices,
712715
batch_shape=batch_shape,
713716
transform_on_train=transform_on_train,
@@ -745,7 +748,10 @@ def _update_coefficients(self, X) -> None:
745748
coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - offset
746749
almost_zero = coefficient < self.min_range
747750
self._coefficient = torch.where(almost_zero, 1.0, coefficient)
748-
self._offset = torch.where(almost_zero, 0.0, offset)
751+
self._offset = (
752+
torch.where(almost_zero, 0.0, offset)
753+
+ (0.5 - self.center) * self._coefficient
754+
)
749755

750756
def get_init_args(self) -> dict[str, Any]:
751757
r"""Get the arguments necessary to construct an exact copy of the transform."""

test/models/transforms/test_input.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import itertools
88
from abc import ABC
99
from copy import deepcopy
10+
from itertools import product
1011
from random import randint
1112

1213
import torch
@@ -259,17 +260,19 @@ def test_normalize(self) -> None:
259260
nlz(X)
260261

261262
# basic usage
262-
for batch_shape in (torch.Size(), torch.Size([3])):
263+
for batch_shape, center in product(
264+
(torch.Size(), torch.Size([3])), [0.5, 0.0]
265+
):
263266
# learned bounds
264-
nlz = Normalize(d=2, batch_shape=batch_shape)
267+
nlz = Normalize(d=2, batch_shape=batch_shape, center=center)
265268
X = torch.randn(*batch_shape, 4, 2, device=self.device, dtype=dtype)
266269
for _X in (torch.stack((X, X)), X): # check batch_shape is obeyed
267270
X_nlzd = nlz(_X)
268271
self.assertEqual(nlz.mins.shape, batch_shape + (1, X.shape[-1]))
269272
self.assertEqual(nlz.ranges.shape, batch_shape + (1, X.shape[-1]))
270273

271-
self.assertEqual(X_nlzd.min().item(), 0.0)
272-
self.assertEqual(X_nlzd.max().item(), 1.0)
274+
self.assertAllClose(X_nlzd.min().item(), center - 0.5)
275+
self.assertAllClose(X_nlzd.max().item(), center + 0.5)
273276

274277
nlz.eval()
275278
X_unnlzd = nlz.untransform(X_nlzd)
@@ -278,6 +281,9 @@ def test_normalize(self) -> None:
278281
[X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
279282
dim=-2,
280283
)
284+
coeff = expected_bounds[..., 1, :] - expected_bounds[..., 0, :]
285+
expected_bounds[..., 0, :] += (0.5 - center) * coeff
286+
expected_bounds[..., 1, :] = expected_bounds[..., 0, :] + coeff
281287
atol = 1e-6 if dtype is torch.float32 else 1e-12
282288
rtol = 1e-4 if dtype is torch.float32 else 1e-8
283289
self.assertAllClose(nlz.bounds, expected_bounds, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)