
Commit bf91ff6

albertbou92 and vmoens authored
[Feature] Crop Transform (#2336)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent b3f99b3 commit bf91ff6

File tree

5 files changed: +278 / -0 lines changed


docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
@@ -793,6 +793,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    Crop
     DTypeCastTransform
     DeviceCastTransform
     DiscreteActionProjection

test/test_transforms.py

Lines changed: 208 additions & 0 deletions
@@ -70,6 +70,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
     DMControlEnv,
@@ -2135,6 +2136,213 @@ def test_transform_inverse(self):
         raise pytest.skip("No inverse for CatTensors")
 
 
+@pytest.mark.skipif(not _has_tv, reason="no torchvision")
+class TestCrop(TransformBase):
+    @pytest.mark.parametrize("nchannels", [1, 3])
+    @pytest.mark.parametrize("batch", [[], [2], [2, 4]])
+    @pytest.mark.parametrize("h", [None, 21])
+    @pytest.mark.parametrize(
+        "keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]]
+    )
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_transform_no_env(self, keys, h, nchannels, batch, device):
+        torch.manual_seed(0)
+        dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
+        crop = Crop(w=20, h=h, in_keys=keys)
+        if h is None:
+            h = 20
+        td = TensorDict(
+            {
+                key: torch.randn(*batch, nchannels, 16, 16, device=device)
+                for key in keys
+            },
+            batch,
+            device=device,
+        )
+        td.set("dont touch", dont_touch.clone())
+        crop(td)
+        for key in keys:
+            assert td.get(key).shape[-2:] == torch.Size([20, h])
+        assert (td.get("dont touch") == dont_touch).all()
+
+        if len(keys) == 1:
+            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+            observation_spec = crop.transform_observation_spec(observation_spec)
+            assert observation_spec.shape == torch.Size([nchannels, 20, h])
+        else:
+            observation_spec = CompositeSpec(
+                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+            )
+            observation_spec = crop.transform_observation_spec(observation_spec)
+            for key in keys:
+                assert observation_spec[key].shape == torch.Size([nchannels, 20, h])
+
+    @pytest.mark.parametrize("nchannels", [3])
+    @pytest.mark.parametrize("batch", [[2]])
+    @pytest.mark.parametrize("h", [None])
+    @pytest.mark.parametrize("keys", [["observation_pixels"]])
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_transform_model(self, keys, h, nchannels, batch, device):
+        torch.manual_seed(0)
+        dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
+        crop = Crop(w=20, h=h, in_keys=keys)
+        if h is None:
+            h = 20
+        td = TensorDict(
+            {
+                key: torch.randn(*batch, nchannels, 16, 16, device=device)
+                for key in keys
+            },
+            batch,
+            device=device,
+        )
+        td.set("dont touch", dont_touch.clone())
+        model = nn.Sequential(crop, nn.Identity())
+        model(td)
+        for key in keys:
+            assert td.get(key).shape[-2:] == torch.Size([20, h])
+        assert (td.get("dont touch") == dont_touch).all()
+
+    @pytest.mark.parametrize("nchannels", [3])
+    @pytest.mark.parametrize("batch", [[2]])
+    @pytest.mark.parametrize("h", [None])
+    @pytest.mark.parametrize("keys", [["observation_pixels"]])
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_transform_compose(self, keys, h, nchannels, batch, device):
+        torch.manual_seed(0)
+        dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device)
+        crop = Crop(w=20, h=h, in_keys=keys)
+        if h is None:
+            h = 20
+        td = TensorDict(
+            {
+                key: torch.randn(*batch, nchannels, 16, 16, device=device)
+                for key in keys
+            },
+            batch,
+            device=device,
+        )
+        td.set("dont touch", dont_touch.clone())
+        model = Compose(crop)
+        tdc = model(td.clone())
+        for key in keys:
+            assert tdc.get(key).shape[-2:] == torch.Size([20, h])
+        assert (tdc.get("dont touch") == dont_touch).all()
+        tdc = model._call(td.clone())
+        for key in keys:
+            assert tdc.get(key).shape[-2:] == torch.Size([20, h])
+        assert (tdc.get("dont touch") == dont_touch).all()
+
+    @pytest.mark.parametrize("nchannels", [3])
+    @pytest.mark.parametrize("batch", [[2]])
+    @pytest.mark.parametrize("h", [None])
+    @pytest.mark.parametrize("keys", [["observation_pixels"]])
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(
+        self,
+        rbclass,
+        keys,
+        h,
+        nchannels,
+        batch,
+    ):
+        torch.manual_seed(0)
+        dont_touch = torch.randn(
+            *batch,
+            nchannels,
+            16,
+            16,
+        )
+        crop = Crop(w=20, h=h, in_keys=keys)
+        if h is None:
+            h = 20
+        td = TensorDict(
+            {
+                key: torch.randn(
+                    *batch,
+                    nchannels,
+                    16,
+                    16,
+                )
+                for key in keys
+            },
+            batch,
+        )
+        td.set("dont touch", dont_touch.clone())
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(crop)
+        rb.extend(td)
+        td = rb.sample(10)
+        for key in keys:
+            assert td.get(key).shape[-2:] == torch.Size([20, h])
+
+    def test_single_trans_env_check(self):
+        keys = ["pixels"]
+        ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
+        env = TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)
+        check_env_specs(env)
+
+    def test_serial_trans_env_check(self):
+        keys = ["pixels"]
+
+        def make_env():
+            ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
+            return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env)
+
+    def test_parallel_trans_env_check(self):
+        keys = ["pixels"]
+
+        def make_env():
+            ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
+            return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct)
+
+        env = ParallelEnv(2, make_env)
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    def test_trans_serial_env_check(self):
+        keys = ["pixels"]
+        ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
+        env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct)
+        check_env_specs(env)
+
+    def test_trans_parallel_env_check(self):
+        keys = ["pixels"]
+        ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys))
+        env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct)
+        try:
+            check_env_specs(env)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.skipif(not _has_gym, reason="No Gym detected")
+    @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]])
+    def test_transform_env(self, out_key):
+        keys = ["pixels"]
+        ct = Compose(ToTensorImage(), Crop(out_keys=out_key, w=20, h=20, in_keys=keys))
+        env = TransformedEnv(GymEnv(PONG_VERSIONED()), ct)
+        td = env.reset()
+        if out_key is None:
+            assert td["pixels"].shape == torch.Size([3, 20, 20])
+        else:
+            assert td[out_key[0]].shape == torch.Size([3, 20, 20])
+        check_env_specs(env)
+
+    def test_transform_inverse(self):
+        raise pytest.skip("Crop does not have an inverse method.")
+
+
 @pytest.mark.skipif(not _has_tv, reason="no torchvision")
 class TestCenterCrop(TransformBase):
     @pytest.mark.parametrize("nchannels", [1, 3])

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
     DoubleToFloat,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
     DoubleToFloat,

torchrl/envs/transforms/transforms.py

Lines changed: 67 additions & 0 deletions
@@ -1913,6 +1913,73 @@ def _reset(
         return tensordict_reset
 
 
+class Crop(ObservationTransform):
+    """Crops the input image at the specified location and output size.
+
+    Args:
+        w (int): resulting width.
+        h (int, optional): resulting height. If None, then w is used (square crop).
+        top (int, optional): top pixel coordinate to start cropping. Default is 0, i.e. the top of the image.
+        left (int, optional): left pixel coordinate to start cropping. Default is 0, i.e. the left of the image.
+        in_keys (sequence of NestedKey, optional): the entries to crop. If none is provided,
+            ``["pixels"]`` is assumed.
+        out_keys (sequence of NestedKey, optional): the cropped images keys. If none is
+            provided, ``in_keys`` is assumed.
+
+    """
+
+    def __init__(
+        self,
+        w: int,
+        h: int = None,
+        top: int = 0,
+        left: int = 0,
+        in_keys: Sequence[NestedKey] | None = None,
+        out_keys: Sequence[NestedKey] | None = None,
+    ):
+        if in_keys is None:
+            in_keys = IMAGE_KEYS  # default
+        if out_keys is None:
+            out_keys = copy(in_keys)
+        super().__init__(in_keys=in_keys, out_keys=out_keys)
+        self.w = w
+        self.h = h if h else w
+        self.top = top
+        self.left = left
+
+    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
+        from torchvision.transforms.functional import crop
+
+        observation = crop(observation, self.top, self.left, self.w, self.h)
+        return observation
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        with _set_missing_tolerance(self, True):
+            tensordict_reset = self._call(tensordict_reset)
+        return tensordict_reset
+
+    @_apply_to_composite
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        space = observation_spec.space
+        if isinstance(space, ContinuousBox):
+            space.low = self._apply_transform(space.low)
+            space.high = self._apply_transform(space.high)
+            observation_spec.shape = space.low.shape
+        else:
+            observation_spec.shape = self._apply_transform(
+                torch.zeros(observation_spec.shape)
+            ).shape
+        return observation_spec
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"w={float(self.w):4.4f}, h={float(self.h):4.4f}, top={float(self.top):4.4f}, left={float(self.left):4.4f}, "
+        )
+
+
 class CenterCrop(ObservationTransform):
     """Crops the center of an image.
 
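As a reading aid, here is a minimal usage sketch of the transform added by this commit; it is not part of the diff. The batch size, image size of 3x64x64, and the "pixels" key are illustrative assumptions; the call pattern mirrors the tests above (torchvision is required for the underlying crop op).

# Illustrative sketch only -- sizes and keys are assumptions, not part of the commit.
import torch
from tensordict import TensorDict
from torchrl.envs import Crop  # exported by this commit via torchrl/envs/__init__.py

# Crop a batch of 3x64x64 "pixels" observations down to 20x20, starting at the
# top-left corner (top=0 and left=0 are the defaults).
td = TensorDict(
    {"pixels": torch.randn(4, 3, 64, 64)},
    batch_size=[4],
)
crop = Crop(w=20, h=20, in_keys=["pixels"])
crop(td)  # applied in place on the tensordict
assert td["pixels"].shape[-2:] == torch.Size([20, 20])

The tests above exercise the same transform inside Compose, nn.Sequential, replay buffers, and TransformedEnv instances.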
