Skip to content

Commit 4c6f563

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 4102861 + 4791529 commit 4c6f563

File tree

6 files changed

+35
-15
lines changed

6 files changed

+35
-15
lines changed

.github/unittest/linux_libs/scripts_chess/environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@ dependencies:
1818
- scipy
1919
- hydra-core
2020
- chess
21+
- transformers
22+
- cairosvg

.github/unittest/linux_libs/scripts_chess/install.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ git submodule sync && git submodule update --init --recursive
2828
printf "Installing PyTorch with cu121"
2929
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3030
if [ "${CU_VERSION:-}" == cpu ] ; then
31-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
31+
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
3232
else
33-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
33+
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
3434
fi
3535
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3636
if [ "${CU_VERSION:-}" == cpu ] ; then
37-
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
37+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
3838
else
39-
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
39+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
4040
fi
4141
else
4242
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_minari/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ dependencies:
1717
- pyyaml
1818
- scipy
1919
- hydra-core
20-
- minari[gcs,hdf5]
20+
- minari[gcs,hdf5,hf]
2121
- gymnasium<1.0.0

test/test_env.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@
172172
mp_ctx = "fork"
173173

174174
_has_chess = importlib.util.find_spec("chess") is not None
175-
175+
_has_tv = importlib.util.find_spec("torchvision") is not None
176+
_has_cairosvg = importlib.util.find_spec("cairosvg") is not None
176177
## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between
177178
## the serial and parallel batched envs
178179
# def _make_atari_env(atari_env):
@@ -3471,6 +3472,15 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
34713472
if include_san:
34723473
assert "san_hash" in env.observation_spec.keys()
34733474

3475+
@pytest.mark.skipif(not _has_tv, reason="torchvision not found.")
3476+
@pytest.mark.skipif(not _has_cairosvg, reason="cairosvg not found.")
3477+
@pytest.mark.parametrize("stateful", [False, True])
3478+
def test_chess_rendering(self, stateful):
3479+
env = ChessEnv(stateful=stateful, include_fen=True, pixels=True)
3480+
env.check_env_specs()
3481+
r = env.rollout(3)
3482+
assert "pixels" in r
3483+
34743484
def test_pgn_bijectivity(self):
34753485
np.random.seed(0)
34763486
pgn = ChessEnv._PGN_RESTART

test/test_libs.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import importlib.util
88
import urllib.error
99

10-
from gym.core import ObsType
11-
1210
_has_isaac = importlib.util.find_spec("isaacgym") is not None
1311

1412
if _has_isaac:
@@ -25,7 +23,7 @@
2523
from contextlib import nullcontext
2624
from pathlib import Path
2725
from sys import platform
28-
from typing import Optional, Tuple, Union
26+
from typing import Optional, Union
2927
from unittest import mock
3028

3129
import numpy as np
@@ -638,7 +636,8 @@ def test_torchrl_to_gym(self, backend, numpy):
638636

639637
@implement_for("gym", None, "0.26")
640638
def test_gym_dict_action_space(self):
641-
pytest.skip("tested for gym > 0.26 - no backward issue")
639+
torchrl_logger.info("tested for gym > 0.26 - no backward issue")
640+
return
642641

643642
@implement_for("gym", "0.26", None)
644643
def test_gym_dict_action_space(self): # noqa: F811
@@ -653,14 +652,17 @@ def __init__(self):
653652
self.observation_space = gym.spaces.Box(-1, 1)
654653

655654
def step(self, action):
655+
assert isinstance(action, dict)
656+
assert isinstance(action["a0"], np.ndarray)
657+
assert isinstance(action["a1"], np.ndarray)
656658
return (0.5, 0.0, False, False, {})
657659

658660
def reset(
659661
self,
660662
*,
661663
seed: Optional[int] = None,
662664
options: Optional[dict] = None,
663-
) -> Tuple[ObsType, dict]:
665+
):
664666
return (0.0, {})
665667

666668
env = CompositeActionEnv()
@@ -686,14 +688,17 @@ def __init__(self):
686688
self.observation_space = gym.spaces.Box(-1, 1)
687689

688690
def step(self, action):
691+
assert isinstance(action, dict)
692+
assert isinstance(action["a0"], np.ndarray)
693+
assert isinstance(action["a1"], np.ndarray)
689694
return (0.5, 0.0, False, False, {})
690695

691696
def reset(
692697
self,
693698
*,
694699
seed: Optional[int] = None,
695700
options: Optional[dict] = None,
696-
) -> Tuple[ObsType, dict]:
701+
):
697702
return (0.0, {})
698703

699704
env = CompositeActionEnv()

torchrl/envs/custom/chess.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import Dict
1111

1212
import torch
13-
from PIL import Image
1413
from tensordict import TensorDict, TensorDictBase
1514
from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded
1615

@@ -315,7 +314,9 @@ def __init__(
315314
raise ImportError(
316315
"Please install torchvision to use this environment with pixel rendering."
317316
)
318-
self.full_observation_spec["pixels"] = Unbounded(shape=())
317+
self.full_observation_spec["pixels"] = Unbounded(
318+
shape=(3, 390, 390), dtype=torch.uint8
319+
)
319320

320321
self.full_action_spec = Composite(
321322
action=Categorical(n=len(self.san_moves), shape=(), dtype=torch.int64)
@@ -428,6 +429,8 @@ def _torchvision(cls):
428429
@classmethod
429430
def _get_tensor_image(cls, board):
430431
try:
432+
from PIL import Image
433+
431434
svg = board._repr_svg_()
432435
# Convert SVG to PNG using cairosvg
433436
png_data = io.BytesIO()
@@ -438,7 +441,7 @@ def _get_tensor_image(cls, board):
438441
img = cls._torchvision.transforms.functional.pil_to_tensor(img)
439442
except ImportError:
440443
raise ImportError(
441-
"Chess rendering requires cairosvg and torchvision to be installed."
444+
"Chess rendering requires cairosvg, PIL and torchvision to be installed."
442445
)
443446
return img
444447

0 commit comments

Comments
 (0)