34
34
except ImportError :
35
35
_has_tv = False
36
36
37
+ try :
38
+ from torchvision .models import ResNet18_Weights , ResNet34_Weights , ResNet50_Weights
39
+ from torchvision .models ._api import WeightsEnum
40
+ except ImportError :
41
+
42
+ class WeightsEnum : # noqa: D101
43
+ # placeholder
44
+ pass
45
+
46
+
47
+ R3M_MODEL_MAP = {
48
+ "resnet18" : "r3m_18" ,
49
+ "resnet34" : "r3m_34" ,
50
+ "resnet50" : "r3m_50" ,
51
+ }
52
+
37
53
38
54
class _R3MNet (Transform ):
39
55
@@ -45,18 +61,19 @@ def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
45
61
"Tried to instantiate R3M without torchvision. Make sure you have "
46
62
"torchvision installed in your environment."
47
63
)
64
+ self .model_name = model_name
48
65
if model_name == "resnet18" :
49
- self .model_name = "r3m_18"
66
+ # self.model_name = "r3m_18"
50
67
self .outdim = 512
51
- convnet = models .resnet18 (pretrained = False )
68
+ convnet = models .resnet18 (None )
52
69
elif model_name == "resnet34" :
53
- self .model_name = "r3m_34"
70
+ # self.model_name = "r3m_34"
54
71
self .outdim = 512
55
- convnet = models .resnet34 (pretrained = False )
72
+ convnet = models .resnet34 (None )
56
73
elif model_name == "resnet50" :
57
- self .model_name = "r3m_50"
74
+ # self.model_name = "r3m_50"
58
75
self .outdim = 2048
59
- convnet = models .resnet50 (pretrained = False )
76
+ convnet = models .resnet50 (None )
60
77
else :
61
78
raise NotImplementedError (
62
79
f"model { model_name } is currently not supported by R3M"
@@ -123,8 +140,34 @@ def _load_weights(model_name, r3m_instance, dir_prefix):
123
140
state_dict = td_flatten .to_dict ()
124
141
r3m_instance .convnet .load_state_dict (state_dict )
125
142
126
- def load_weights (self , dir_prefix = None ):
127
- self ._load_weights (self .model_name , self , dir_prefix )
143
+ def load_weights (self , dir_prefix = None , tv_weights = None ):
144
+ if dir_prefix is not None and tv_weights is not None :
145
+ raise RuntimeError (
146
+ "torchvision weights API does not allow for custom download path."
147
+ )
148
+ elif tv_weights is not None :
149
+ model_name = self .model_name
150
+ if model_name == "resnet18" :
151
+ if isinstance (tv_weights , str ):
152
+ tv_weights = getattr (ResNet18_Weights , tv_weights )
153
+ convnet = models .resnet18 (weights = tv_weights )
154
+ elif model_name == "resnet34" :
155
+ if isinstance (tv_weights , str ):
156
+ tv_weights = getattr (ResNet34_Weights , tv_weights )
157
+ convnet = models .resnet34 (weights = tv_weights )
158
+ elif model_name == "resnet50" :
159
+ if isinstance (tv_weights , str ):
160
+ tv_weights = getattr (ResNet50_Weights , tv_weights )
161
+ convnet = models .resnet50 (weights = tv_weights )
162
+ else :
163
+ raise NotImplementedError (
164
+ f"model { model_name } is currently not supported by R3M"
165
+ )
166
+ convnet .fc = Identity ()
167
+ self .convnet .load_state_dict (convnet .state_dict ())
168
+ else :
169
+ model_name = R3M_MODEL_MAP [self .model_name ]
170
+ self ._load_weights (model_name , self , dir_prefix )
128
171
129
172
130
173
def _init_first (fun ):
@@ -154,7 +197,7 @@ class R3MTransform(Compose):
154
197
can ensure that the following code snippet works as expected:
155
198
156
199
Examples:
157
- >>> transform = R3MTransform("resenet50 ", in_keys=["pixels"])
200
+ >>> transform = R3MTransform("resnet50 ", in_keys=["pixels"])
158
201
>>> env.append_transform(transform)
159
202
>>> # the forward method will first call _init which will look at env.observation_spec
160
203
>>> env.reset()
@@ -170,8 +213,13 @@ class R3MTransform(Compose):
170
213
stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
171
214
argument will be treaded separetely and each will be given a single,
172
215
separated entry in the output tensordict. Defaults to :obj:`True`.
173
- download (bool, optional): if True, the weights will be downloaded using
174
- the torch.hub download API (i.e. weights will be cached for future use).
216
+ download (bool, torchvision Weights config or corresponding string):
217
+ if True, the weights will be downloaded using the torch.hub download
218
+ API (i.e. weights will be cached for future use).
219
+ These weights are the original weights from the R3M publication.
220
+ If the torchvision weights are needed, there are two ways they can be
221
+ obtained: :obj:`download=ResNet50_Weights.IMAGENET1K_V1` or :obj:`download="IMAGENET1K_V1"`
222
+ where :obj:`ResNet50_Weights` can be imported via :obj:`from torchvision.models import resnet50, ResNet50_Weights`.
175
223
Defaults to False.
176
224
download_path (str, optional): path where to download the models.
177
225
Default is None (cache path determined by torch.hub utils).
@@ -194,7 +242,7 @@ def __init__(
194
242
out_keys : List [str ] = None ,
195
243
size : int = 244 ,
196
244
stack_images : bool = True ,
197
- download : bool = False ,
245
+ download : Union [ bool , WeightsEnum , str ] = False ,
198
246
download_path : Optional [str ] = None ,
199
247
tensor_pixels_keys : List [str ] = None ,
200
248
):
@@ -302,8 +350,12 @@ def _init(self):
302
350
303
351
for transform in transforms :
304
352
self .append (transform )
305
- if self .download :
306
- self [- 1 ].load_weights (dir_prefix = self .download_path )
353
+ if self .download is True :
354
+ self [- 1 ].load_weights (dir_prefix = self .download_path , tv_weights = None )
355
+ elif self .download :
356
+ self [- 1 ].load_weights (
357
+ dir_prefix = self .download_path , tv_weights = self .download
358
+ )
307
359
308
360
if self ._device is not None :
309
361
self .to (self ._device )
0 commit comments