Skip to content

Commit f741deb

Browse files
authored
[Feature] Loading R3M and VIP from ResNet (#863)
1 parent 3b64392 commit f741deb

File tree

3 files changed

+122
-44
lines changed

3 files changed

+122
-44
lines changed

test/test_transforms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,6 +1677,10 @@ def test_pin_mem(self, device):
16771677
td = TensorDict(
16781678
{key: torch.randn(3) for key in ["a", "b", "c"]}, [], device=device
16791679
)
1680+
if device.type == "cuda":
1681+
with pytest.raises(RuntimeError, match="cannot pin"):
1682+
pin_mem(td)
1683+
return
16801684
pin_mem(td)
16811685
for item in td.values():
16821686
assert item.is_pinned

torchrl/envs/transforms/r3m.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@
3434
except ImportError:
3535
_has_tv = False
3636

37+
try:
38+
from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights
39+
from torchvision.models._api import WeightsEnum
40+
except ImportError:
41+
42+
class WeightsEnum: # noqa: D101
43+
# placeholder
44+
pass
45+
46+
47+
R3M_MODEL_MAP = {
48+
"resnet18": "r3m_18",
49+
"resnet34": "r3m_34",
50+
"resnet50": "r3m_50",
51+
}
52+
3753

3854
class _R3MNet(Transform):
3955

@@ -45,18 +61,19 @@ def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
4561
"Tried to instantiate R3M without torchvision. Make sure you have "
4662
"torchvision installed in your environment."
4763
)
64+
self.model_name = model_name
4865
if model_name == "resnet18":
49-
self.model_name = "r3m_18"
66+
# self.model_name = "r3m_18"
5067
self.outdim = 512
51-
convnet = models.resnet18(pretrained=False)
68+
convnet = models.resnet18(None)
5269
elif model_name == "resnet34":
53-
self.model_name = "r3m_34"
70+
# self.model_name = "r3m_34"
5471
self.outdim = 512
55-
convnet = models.resnet34(pretrained=False)
72+
convnet = models.resnet34(None)
5673
elif model_name == "resnet50":
57-
self.model_name = "r3m_50"
74+
# self.model_name = "r3m_50"
5875
self.outdim = 2048
59-
convnet = models.resnet50(pretrained=False)
76+
convnet = models.resnet50(None)
6077
else:
6178
raise NotImplementedError(
6279
f"model {model_name} is currently not supported by R3M"
@@ -123,8 +140,34 @@ def _load_weights(model_name, r3m_instance, dir_prefix):
123140
state_dict = td_flatten.to_dict()
124141
r3m_instance.convnet.load_state_dict(state_dict)
125142

126-
def load_weights(self, dir_prefix=None):
127-
self._load_weights(self.model_name, self, dir_prefix)
143+
def load_weights(self, dir_prefix=None, tv_weights=None):
144+
if dir_prefix is not None and tv_weights is not None:
145+
raise RuntimeError(
146+
"torchvision weights API does not allow for custom download path."
147+
)
148+
elif tv_weights is not None:
149+
model_name = self.model_name
150+
if model_name == "resnet18":
151+
if isinstance(tv_weights, str):
152+
tv_weights = getattr(ResNet18_Weights, tv_weights)
153+
convnet = models.resnet18(weights=tv_weights)
154+
elif model_name == "resnet34":
155+
if isinstance(tv_weights, str):
156+
tv_weights = getattr(ResNet34_Weights, tv_weights)
157+
convnet = models.resnet34(weights=tv_weights)
158+
elif model_name == "resnet50":
159+
if isinstance(tv_weights, str):
160+
tv_weights = getattr(ResNet50_Weights, tv_weights)
161+
convnet = models.resnet50(weights=tv_weights)
162+
else:
163+
raise NotImplementedError(
164+
f"model {model_name} is currently not supported by R3M"
165+
)
166+
convnet.fc = Identity()
167+
self.convnet.load_state_dict(convnet.state_dict())
168+
else:
169+
model_name = R3M_MODEL_MAP[self.model_name]
170+
self._load_weights(model_name, self, dir_prefix)
128171

129172

130173
def _init_first(fun):
@@ -154,7 +197,7 @@ class R3MTransform(Compose):
154197
can ensure that the following code snippet works as expected:
155198
156199
Examples:
157-
>>> transform = R3MTransform("resenet50", in_keys=["pixels"])
200+
>>> transform = R3MTransform("resnet50", in_keys=["pixels"])
158201
>>> env.append_transform(transform)
159202
>>> # the forward method will first call _init which will look at env.observation_spec
160203
>>> env.reset()
@@ -170,8 +213,13 @@ class R3MTransform(Compose):
170213
stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
171214
argument will be treaded separetely and each will be given a single,
172215
separated entry in the output tensordict. Defaults to :obj:`True`.
173-
download (bool, optional): if True, the weights will be downloaded using
174-
the torch.hub download API (i.e. weights will be cached for future use).
216+
download (bool, torchvision Weights config or corresponding string):
217+
if True, the weights will be downloaded using the torch.hub download
218+
API (i.e. weights will be cached for future use).
219+
These weights are the original weights from the R3M publication.
220+
If the torchvision weights are needed, there are two ways they can be
221+
obtained: :obj:`download=ResNet50_Weights.IMAGENET1K_V1` or :obj:`download="IMAGENET1K_V1"`
222+
where :obj:`ResNet50_Weights` can be imported via :obj:`from torchvision.models import resnet50, ResNet50_Weights`.
175223
Defaults to False.
176224
download_path (str, optional): path where to download the models.
177225
Default is None (cache path determined by torch.hub utils).
@@ -194,7 +242,7 @@ def __init__(
194242
out_keys: List[str] = None,
195243
size: int = 244,
196244
stack_images: bool = True,
197-
download: bool = False,
245+
download: Union[bool, WeightsEnum, str] = False,
198246
download_path: Optional[str] = None,
199247
tensor_pixels_keys: List[str] = None,
200248
):
@@ -302,8 +350,12 @@ def _init(self):
302350

303351
for transform in transforms:
304352
self.append(transform)
305-
if self.download:
306-
self[-1].load_weights(dir_prefix=self.download_path)
353+
if self.download is True:
354+
self[-1].load_weights(dir_prefix=self.download_path, tv_weights=None)
355+
elif self.download:
356+
self[-1].load_weights(
357+
dir_prefix=self.download_path, tv_weights=self.download
358+
)
307359

308360
if self._device is not None:
309361
self.to(self._device)

torchrl/envs/transforms/vip.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@
3434
except ImportError:
3535
_has_tv = False
3636

37+
try:
38+
from torchvision.models import ResNet50_Weights
39+
from torchvision.models._api import WeightsEnum
40+
except ImportError:
41+
42+
class WeightsEnum: # noqa: D101
43+
# placeholder
44+
pass
45+
46+
47+
VIP_MODEL_MAP = {
48+
"resnet50": "vip_50",
49+
}
50+
3751

3852
class _VIPNet(Transform):
3953

@@ -45,8 +59,8 @@ def __init__(self, in_keys, out_keys, model_name="resnet50", del_keys: bool = Tr
4559
"Tried to instantiate VIP without torchvision. Make sure you have "
4660
"torchvision installed in your environment."
4761
)
62+
self.model_name = model_name
4863
if model_name == "resnet50":
49-
self.model_name = "vip_50"
5064
self.outdim = 2048
5165
convnet = models.resnet50(pretrained=False)
5266
convnet.fc = torch.nn.Linear(self.outdim, 1024)
@@ -98,8 +112,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
98112

99113
@staticmethod
100114
def _load_weights(model_name, vip_instance, dir_prefix):
101-
if model_name not in ("vip_50"):
102-
raise ValueError("model_name should be 'vip_50'")
115+
if model_name not in ("vip_50",):
116+
raise ValueError(f"model_name should be 'vip_50', got {model_name}")
103117
url = "https://pytorch.s3.amazonaws.com/models/rl/vip/model.pt"
104118
d = load_state_dict_from_url(
105119
url,
@@ -112,8 +126,27 @@ def _load_weights(model_name, vip_instance, dir_prefix):
112126
state_dict = td_flatten.to_dict()
113127
vip_instance.convnet.load_state_dict(state_dict)
114128

115-
def load_weights(self, dir_prefix=None):
116-
self._load_weights(self.model_name, self, dir_prefix)
129+
def load_weights(self, dir_prefix=None, tv_weights=None):
130+
if dir_prefix is not None and tv_weights is not None:
131+
raise RuntimeError(
132+
"torchvision weights API does not allow for custom download path."
133+
)
134+
elif tv_weights is not None:
135+
model_name = self.model_name
136+
if model_name == "resnet50":
137+
if isinstance(tv_weights, str):
138+
tv_weights = getattr(ResNet50_Weights, tv_weights)
139+
convnet = models.resnet50(weights=tv_weights)
140+
else:
141+
raise NotImplementedError(
142+
f"model {model_name} is currently not supported by R3M"
143+
)
144+
convnet.fc = torch.nn.Linear(self.outdim, 1024)
145+
self.convnet.load_state_dict(convnet.state_dict())
146+
147+
else:
148+
model_name = VIP_MODEL_MAP[self.model_name]
149+
self._load_weights(model_name, self, dir_prefix)
117150

118151

119152
def _init_first(fun):
@@ -145,8 +178,13 @@ class VIPTransform(Compose):
145178
stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
146179
argument will be treaded separetely and each will be given a single,
147180
separated entry in the output tensordict. Defaults to :obj:`True`.
148-
download (bool, optional): if True, the weights will be downloaded using
149-
the torch.hub download API (i.e. weights will be cached for future use).
181+
download (bool, torchvision Weights config or corresponding string):
182+
if True, the weights will be downloaded using the torch.hub download
183+
API (i.e. weights will be cached for future use).
184+
These weights are the original weights from the VIP publication.
185+
If the torchvision weights are needed, there are two ways they can be
186+
obtained: :obj:`download=ResNet50_Weights.IMAGENET1K_V1` or :obj:`download="IMAGENET1K_V1"`
187+
where :obj:`ResNet50_Weights` can be imported via :obj:`from torchvision.models import resnet50, ResNet50_Weights`.
150188
Defaults to False.
151189
download_path (str, optional): path where to download the models.
152190
Default is None (cache path determined by torch.hub utils).
@@ -169,7 +207,7 @@ def __init__(
169207
out_keys: List[str] = None,
170208
size: int = 244,
171209
stack_images: bool = True,
172-
download: bool = False,
210+
download: Union[bool, WeightsEnum, str] = False,
173211
download_path: Optional[str] = None,
174212
tensor_pixels_keys: List[str] = None,
175213
):
@@ -275,34 +313,18 @@ def _init(self):
275313

276314
for transform in transforms:
277315
self.append(transform)
278-
if self.download:
279-
self[-1].load_weights(dir_prefix=self.download_path)
316+
if self.download is True:
317+
self[-1].load_weights(dir_prefix=self.download_path, tv_weights=None)
318+
elif self.download:
319+
self[-1].load_weights(
320+
dir_prefix=self.download_path, tv_weights=self.download
321+
)
280322

281323
if self._device is not None:
282324
self.to(self._device)
283325
if self._dtype is not None:
284326
self.to(self._dtype)
285327

286-
@property
287-
def is_3d(self):
288-
"""Whether the input image has 3 dims (no-batched) or more.
289-
290-
If no parent environment exists, it defaults to True.
291-
292-
The main usage is this: if there are more than one image and they need to be
293-
stacked, we must know if the input image has dim 3 or 4. If 3, we need to unsqueeze
294-
before stacking. If 4, we can cat along the first dimension.
295-
296-
"""
297-
if self._is_3d is None:
298-
parent = self.parent
299-
if parent is None:
300-
return True
301-
for key in parent.observation_spec.keys():
302-
self._is_3d = len(parent.observation_spec[key].shape) == 3
303-
break
304-
return self._is_3d
305-
306328
def to(self, dest: Union[DEVICE_TYPING, torch.dtype]):
307329
if isinstance(dest, torch.dtype):
308330
self._dtype = dest

0 commit comments

Comments
 (0)