Skip to content

Commit ddbb6fd

Browse files
author
Vincent Moens
committed
[BugFix,Test] test chess rendering
ghstack-source-id: 59b37e6 Pull Request resolved: #2721
1 parent d628a50 commit ddbb6fd

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

.github/unittest/linux_libs/scripts_chess/environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@ dependencies:
1818
- scipy
1919
- hydra-core
2020
- chess
21+
- transformers
22+
- cairosvg

.github/unittest/linux_libs/scripts_chess/install.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ git submodule sync && git submodule update --init --recursive
2828
printf "Installing PyTorch with cu121"
2929
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3030
if [ "${CU_VERSION:-}" == cpu ] ; then
31-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
31+
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
3232
else
33-
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
33+
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
3434
fi
3535
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3636
if [ "${CU_VERSION:-}" == cpu ] ; then
37-
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
37+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
3838
else
39-
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
39+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
4040
fi
4141
else
4242
printf "Failed to install pytorch"

test/test_env.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@
172172
mp_ctx = "fork"
173173

174174
_has_chess = importlib.util.find_spec("chess") is not None
175-
175+
_has_tv = importlib.util.find_spec("torchvision") is not None
176+
_has_cairosvg = importlib.util.find_spec("cairosvg") is not None
176177
## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between
177178
## the serial and parallel batched envs
178179
# def _make_atari_env(atari_env):
@@ -3471,6 +3472,15 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
34713472
if include_san:
34723473
assert "san_hash" in env.observation_spec.keys()
34733474

3475+
@pytest.mark.skipif(not _has_tv, reason="torchvision not found.")
3476+
@pytest.mark.skipif(not _has_cairosvg, reason="cairosvg not found.")
3477+
@pytest.mark.parametrize("stateful", [False, True])
3478+
def test_chess_rendering(self, stateful):
3479+
env = ChessEnv(stateful=stateful, include_fen=True, pixels=True)
3480+
env.check_env_specs()
3481+
r = env.rollout(3)
3482+
assert "pixels" in r
3483+
34743484
def test_pgn_bijectivity(self):
34753485
np.random.seed(0)
34763486
pgn = ChessEnv._PGN_RESTART

torchrl/envs/custom/chess.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,9 @@ def __init__(
316316
raise ImportError(
317317
"Please install torchvision to use this environment with pixel rendering."
318318
)
319-
self.full_observation_spec["pixels"] = Unbounded(shape=())
319+
self.full_observation_spec["pixels"] = Unbounded(
320+
shape=(3, 390, 390), dtype=torch.uint8
321+
)
320322

321323
self.full_action_spec = Composite(
322324
action=Categorical(n=len(self.san_moves), shape=(), dtype=torch.int64)

0 commit comments

Comments
 (0)