|
70 | 70 | CenterCrop,
71 | 71 | ClipTransform,
72 | 72 | Compose,
| 73 | + Crop, |
73 | 74 | DeviceCastTransform,
74 | 75 | DiscreteActionProjection,
75 | 76 | DMControlEnv,

@@ -2135,6 +2136,213 @@ def test_transform_inverse(self):

2135 | 2136 | raise pytest.skip("No inverse for CatTensors")
2136 | 2137 |
2137 | 2138 |
| 2139 | +@pytest.mark.skipif(not _has_tv, reason="no torchvision") |
| 2140 | +class TestCrop(TransformBase): |
| 2141 | + @pytest.mark.parametrize("nchannels", [1, 3]) |
| 2142 | + @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) |
| 2143 | + @pytest.mark.parametrize("h", [None, 21]) |
| 2144 | + @pytest.mark.parametrize( |
| 2145 | + "keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]] |
| 2146 | + ) |
| 2147 | + @pytest.mark.parametrize("device", get_default_devices()) |
| 2148 | + def test_transform_no_env(self, keys, h, nchannels, batch, device): |
| 2149 | + torch.manual_seed(0) |
| 2150 | + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) |
| 2151 | + crop = Crop(w=20, h=h, in_keys=keys) |
| 2152 | + if h is None: |
| 2153 | + h = 20 |
| 2154 | + td = TensorDict( |
| 2155 | + { |
| 2156 | + key: torch.randn(*batch, nchannels, 16, 16, device=device) |
| 2157 | + for key in keys |
| 2158 | + }, |
| 2159 | + batch, |
| 2160 | + device=device, |
| 2161 | + ) |
| 2162 | + td.set("dont touch", dont_touch.clone()) |
| 2163 | + crop(td) |
| 2164 | + for key in keys: |
| 2165 | + assert td.get(key).shape[-2:] == torch.Size([20, h]) |
| 2166 | + assert (td.get("dont touch") == dont_touch).all() |
| 2167 | + |
| 2168 | + if len(keys) == 1: |
| 2169 | + observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) |
| 2170 | + observation_spec = crop.transform_observation_spec(observation_spec) |
| 2171 | + assert observation_spec.shape == torch.Size([nchannels, 20, h]) |
| 2172 | + else: |
| 2173 | + observation_spec = CompositeSpec( |
| 2174 | + {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} |
| 2175 | + ) |
| 2176 | + observation_spec = crop.transform_observation_spec(observation_spec) |
| 2177 | + for key in keys: |
| 2178 | + assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) |
| 2179 | + |
| 2180 | + @pytest.mark.parametrize("nchannels", [3]) |
| 2181 | + @pytest.mark.parametrize("batch", [[2]]) |
| 2182 | + @pytest.mark.parametrize("h", [None]) |
| 2183 | + @pytest.mark.parametrize("keys", [["observation_pixels"]]) |
| 2184 | + @pytest.mark.parametrize("device", get_default_devices()) |
| 2185 | + def test_transform_model(self, keys, h, nchannels, batch, device): |
| 2186 | + torch.manual_seed(0) |
| 2187 | + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) |
| 2188 | + crop = Crop(w=20, h=h, in_keys=keys) |
| 2189 | + if h is None: |
| 2190 | + h = 20 |
| 2191 | + td = TensorDict( |
| 2192 | + { |
| 2193 | + key: torch.randn(*batch, nchannels, 16, 16, device=device) |
| 2194 | + for key in keys |
| 2195 | + }, |
| 2196 | + batch, |
| 2197 | + device=device, |
| 2198 | + ) |
| 2199 | + td.set("dont touch", dont_touch.clone()) |
| 2200 | + model = nn.Sequential(crop, nn.Identity()) |
| 2201 | + model(td) |
| 2202 | + for key in keys: |
| 2203 | + assert td.get(key).shape[-2:] == torch.Size([20, h]) |
| 2204 | + assert (td.get("dont touch") == dont_touch).all() |
| 2205 | + |
| 2206 | + @pytest.mark.parametrize("nchannels", [3]) |
| 2207 | + @pytest.mark.parametrize("batch", [[2]]) |
| 2208 | + @pytest.mark.parametrize("h", [None]) |
| 2209 | + @pytest.mark.parametrize("keys", [["observation_pixels"]]) |
| 2210 | + @pytest.mark.parametrize("device", get_default_devices()) |
| 2211 | + def test_transform_compose(self, keys, h, nchannels, batch, device): |
| 2212 | + torch.manual_seed(0) |
| 2213 | + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) |
| 2214 | + crop = Crop(w=20, h=h, in_keys=keys) |
| 2215 | + if h is None: |
| 2216 | + h = 20 |
| 2217 | + td = TensorDict( |
| 2218 | + { |
| 2219 | + key: torch.randn(*batch, nchannels, 16, 16, device=device) |
| 2220 | + for key in keys |
| 2221 | + }, |
| 2222 | + batch, |
| 2223 | + device=device, |
| 2224 | + ) |
| 2225 | + td.set("dont touch", dont_touch.clone()) |
| 2226 | + model = Compose(crop) |
| 2227 | + tdc = model(td.clone()) |
| 2228 | + for key in keys: |
| 2229 | + assert tdc.get(key).shape[-2:] == torch.Size([20, h]) |
| 2230 | + assert (tdc.get("dont touch") == dont_touch).all() |
| 2231 | + tdc = model._call(td.clone()) |
| 2232 | + for key in keys: |
| 2233 | + assert tdc.get(key).shape[-2:] == torch.Size([20, h]) |
| 2234 | + assert (tdc.get("dont touch") == dont_touch).all() |
| 2235 | + |
| 2236 | + @pytest.mark.parametrize("nchannels", [3]) |
| 2237 | + @pytest.mark.parametrize("batch", [[2]]) |
| 2238 | + @pytest.mark.parametrize("h", [None]) |
| 2239 | + @pytest.mark.parametrize("keys", [["observation_pixels"]]) |
| 2240 | + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) |
| 2241 | + def test_transform_rb( |
| 2242 | + self, |
| 2243 | + rbclass, |
| 2244 | + keys, |
| 2245 | + h, |
| 2246 | + nchannels, |
| 2247 | + batch, |
| 2248 | + ): |
| 2249 | + torch.manual_seed(0) |
| 2250 | + dont_touch = torch.randn( |
| 2251 | + *batch, |
| 2252 | + nchannels, |
| 2253 | + 16, |
| 2254 | + 16, |
| 2255 | + ) |
| 2256 | + crop = Crop(w=20, h=h, in_keys=keys) |
| 2257 | + if h is None: |
| 2258 | + h = 20 |
| 2259 | + td = TensorDict( |
| 2260 | + { |
| 2261 | + key: torch.randn( |
| 2262 | + *batch, |
| 2263 | + nchannels, |
| 2264 | + 16, |
| 2265 | + 16, |
| 2266 | + ) |
| 2267 | + for key in keys |
| 2268 | + }, |
| 2269 | + batch, |
| 2270 | + ) |
| 2271 | + td.set("dont touch", dont_touch.clone()) |
| 2272 | + rb = rbclass(storage=LazyTensorStorage(10)) |
| 2273 | + rb.append_transform(crop) |
| 2274 | + rb.extend(td) |
| 2275 | + td = rb.sample(10) |
| 2276 | + for key in keys: |
| 2277 | + assert td.get(key).shape[-2:] == torch.Size([20, h]) |
| 2278 | + |
| 2279 | + def test_single_trans_env_check(self): |
| 2280 | + keys = ["pixels"] |
| 2281 | + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) |
| 2282 | + env = TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) |
| 2283 | + check_env_specs(env) |
| 2284 | + |
| 2285 | + def test_serial_trans_env_check(self): |
| 2286 | + keys = ["pixels"] |
| 2287 | + |
| 2288 | + def make_env(): |
| 2289 | + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) |
| 2290 | + return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) |
| 2291 | + |
| 2292 | + env = SerialEnv(2, make_env) |
| 2293 | + check_env_specs(env) |
| 2294 | + |
| 2295 | + def test_parallel_trans_env_check(self): |
| 2296 | + keys = ["pixels"] |
| 2297 | + |
| 2298 | + def make_env(): |
| 2299 | + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) |
| 2300 | + return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) |
| 2301 | + |
| 2302 | + env = ParallelEnv(2, make_env) |
| 2303 | + try: |
| 2304 | + check_env_specs(env) |
| 2305 | + finally: |
| 2306 | + try: |
| 2307 | + env.close() |
| 2308 | + except RuntimeError: |
| 2309 | + pass |
| 2310 | + |
| 2311 | + def test_trans_serial_env_check(self): |
| 2312 | + keys = ["pixels"] |
| 2313 | + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) |
| 2314 | + env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct) |
| 2315 | + check_env_specs(env) |
| 2316 | + |
| 2317 | + def test_trans_parallel_env_check(self): |
| 2318 | + keys = ["pixels"] |
| 2319 | + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) |
| 2320 | + env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct) |
| 2321 | + try: |
| 2322 | + check_env_specs(env) |
| 2323 | + finally: |
| 2324 | + try: |
| 2325 | + env.close() |
| 2326 | + except RuntimeError: |
| 2327 | + pass |
| 2328 | + |
| 2329 | + @pytest.mark.skipif(not _has_gym, reason="No Gym detected") |
| 2330 | + @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]]) |
| 2331 | + def test_transform_env(self, out_key): |
| 2332 | + keys = ["pixels"] |
| 2333 | + ct = Compose(ToTensorImage(), Crop(out_keys=out_key, w=20, h=20, in_keys=keys)) |
| 2334 | + env = TransformedEnv(GymEnv(PONG_VERSIONED()), ct) |
| 2335 | + td = env.reset() |
| 2336 | + if out_key is None: |
| 2337 | + assert td["pixels"].shape == torch.Size([3, 20, 20]) |
| 2338 | + else: |
| 2339 | + assert td[out_key[0]].shape == torch.Size([3, 20, 20]) |
| 2340 | + check_env_specs(env) |
| 2341 | + |
| 2342 | + def test_transform_inverse(self): |
| 2343 | + raise pytest.skip("Crop does not have an inverse method.") |
| 2344 | + |
| 2345 | + |
2138 | 2346 | @pytest.mark.skipif(not _has_tv, reason="no torchvision")
2139 | 2347 | class TestCenterCrop(TransformBase):
2140 | 2348 | @pytest.mark.parametrize("nchannels", [1, 3])
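For reference, the tests above exercise the new Crop transform both as a standalone module applied to a TensorDict and inside a TransformedEnv or replay-buffer pipeline. Below is a minimal usage sketch, not part of the diff, assuming the torchrl import paths used elsewhere in this test file and the Crop(w=..., h=..., in_keys=..., out_keys=...) signature shown above; the Gym environment id is a placeholder, the tests themselves use PONG_VERSIONED().

# Minimal sketch (assumed API, mirroring the tests in the diff above).
import torch
from tensordict import TensorDict
from torchrl.envs import Compose, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import Crop, ToTensorImage

# Standalone use over a TensorDict, as in test_transform_no_env.
crop = Crop(w=20, h=20, in_keys=["pixels"])
td = TensorDict({"pixels": torch.randn(2, 3, 64, 64)}, [2])
crop(td)
assert td["pixels"].shape[-2:] == torch.Size([20, 20])

# Inside an environment pipeline, as in test_transform_env
# ("ALE/Pong-v5" is a placeholder id; the tests use PONG_VERSIONED()).
env = TransformedEnv(
    GymEnv("ALE/Pong-v5"),
    Compose(
        ToTensorImage(in_keys=["pixels"]),     # HWC uint8 -> CHW float
        Crop(w=20, h=20, in_keys=["pixels"]),  # crop the trailing spatial dims
    ),
)
assert env.reset()["pixels"].shape[-2:] == torch.Size([20, 20])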