Skip to content

Commit 7dfae59

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent 256a700 commit 7dfae59

File tree

11 files changed

+30303
-93
lines changed

11 files changed

+30303
-93
lines changed

examples/agents/ppo-chess.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import tensordict.nn
6+
import torch
7+
import tqdm
8+
from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
9+
ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
10+
from torch import nn
11+
from torch.nn.utils import clip_grad_norm_
12+
from torch.optim import Adam
13+
14+
from torchrl.collectors import SyncDataCollector
15+
16+
from torchrl.envs import ChessEnv, Tokenizer
17+
from torchrl.modules import MLP
18+
from torchrl.modules.distributions import MaskedCategorical
19+
from torchrl.objectives import ClipPPOLoss
20+
from torchrl.objectives.value import GAE
21+
from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
22+
23+
tensordict.nn.set_composite_lp_aggregate(False)
24+
25+
num_epochs = 10
26+
batch_size = 256
27+
frames_per_batch = 2048
28+
29+
env = ChessEnv(include_legal_moves=True, include_fen=True)
30+
31+
# tokenize the fen - assume max 70 elements
32+
transform = Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"], max_length=70)
33+
34+
env = env.append_transform(transform)
35+
n = env.action_spec.n
36+
print(env.rollout(10000))
37+
38+
# Embedding layer for the legal moves
39+
embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)
40+
41+
# Embedding for the fen
42+
embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
43+
44+
backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)
45+
46+
actor_head = nn.Linear(512, env.action_spec.n)
47+
actor_head.bias.data.fill_(0)
48+
49+
critic_head = nn.Linear(512, 1)
50+
critic_head.bias.data.fill_(0)
51+
52+
prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
53+
54+
def make_mask(idx):
55+
mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
56+
return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]
57+
58+
actor = TDProbSeq(
59+
TDMod(
60+
make_mask,
61+
in_keys=["legal_moves"], out_keys=["mask"]),
62+
TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
63+
TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
64+
TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
65+
out_keys=["features"]),
66+
TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
67+
TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
68+
prob,
69+
)
70+
critic = TDSeq(
71+
TDMod(critic_head, in_keys=["hidden"], out_keys=["state_value"]),
72+
)
73+
74+
75+
print(env.rollout(3, actor))
76+
# loss
77+
loss = ClipPPOLoss(actor, critic)
78+
79+
optim = Adam(loss.parameters())
80+
81+
gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
82+
83+
# Create a data collector
84+
collector = SyncDataCollector(
85+
create_env_fn=env,
86+
policy=actor,
87+
frames_per_batch=frames_per_batch,
88+
total_frames=1_000_000,
89+
)
90+
91+
replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
92+
replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
93+
94+
for data in tqdm.tqdm(collector):
95+
data = data.filter_non_tensor_data()
96+
print('data', data[0::2])
97+
for i in range(num_epochs):
98+
replay_buffer0.empty()
99+
replay_buffer1.empty()
100+
with torch.no_grad():
101+
# player 0
102+
data0 = gae(data[0::2])
103+
# player 1
104+
data1 = gae(data[1::2])
105+
if i == 0:
106+
print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
107+
print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
108+
109+
replay_buffer0.extend(data0)
110+
replay_buffer1.extend(data1)
111+
112+
n_iter = collector.frames_per_batch//(2 * batch_size)
113+
for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
114+
loss_vals = (loss(d0) + loss(d1)) / 2
115+
loss_vals.sum(reduce=True).backward()
116+
gn = clip_grad_norm_(loss.parameters(), 100.0)
117+
optim.step()
118+
optim.zero_grad()

test/mocking_classes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,17 +1070,20 @@ def _step(
10701070

10711071
class CountingEnvWithString(CountingEnv):
10721072
def __init__(self, *args, **kwargs):
1073+
self.max_size = kwargs.pop("max_size", 30)
1074+
self.min_size = kwargs.pop("min_size", 4)
10731075
super().__init__(*args, **kwargs)
10741076
self.observation_spec.set(
10751077
"string",
10761078
NonTensor(
10771079
shape=self.batch_size,
10781080
device=self.device,
1081+
example_data=self.get_random_string(),
10791082
),
10801083
)
10811084

10821085
def get_random_string(self):
1083-
size = random.randint(4, 30)
1086+
size = random.randint(self.min_size, self.max_size)
10841087
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
10851088

10861089
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

test/test_specs.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
14021402
assert spec2.zero().shape == spec2.shape
14031403

14041404
def test_non_tensor(self):
1405-
spec = NonTensor((3, 4), device="cpu")
1405+
spec = NonTensor((3, 4), device="cpu", example_data="example_data")
14061406
assert (
14071407
spec.expand(2, 3, 4)
14081408
== spec.expand((2, 3, 4))
1409-
== NonTensor((2, 3, 4), device="cpu")
1409+
== NonTensor((2, 3, 4), device="cpu", example_data="example_data")
14101410
)
1411+
assert spec.expand(2, 3, 4).example_data == "example_data"
14111412

14121413
@pytest.mark.parametrize("shape1", [None, (), (5,)])
14131414
@pytest.mark.parametrize("shape2", [(), (10,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
16071608
assert spec is not spec.clone()
16081609

16091610
def test_non_tensor(self):
1610-
spec = NonTensor(shape=(3, 4), device="cpu")
1611+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
16111612
assert spec.clone() == spec
16121613
assert spec.clone() is not spec
1614+
assert spec.clone().example_data == "example_data"
16131615

16141616
@pytest.mark.parametrize("shape1", [None, (), (5,)])
16151617
def test_onehot(
@@ -1840,9 +1842,10 @@ def test_multionehot(
18401842
spec.unbind(-1)
18411843

18421844
def test_non_tensor(self):
1843-
spec = NonTensor(shape=(3, 4), device="cpu")
1845+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
18441846
assert spec.unbind(1)[0] == spec[:, 0]
18451847
assert spec.unbind(1)[0] is not spec[:, 0]
1848+
assert spec.unbind(1)[0].example_data == "example_data"
18461849

18471850
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
18481851
def test_onehot(
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
20012004
assert spec.to(device).device == device
20022005

20032006
def test_non_tensor(self, device):
2004-
spec = NonTensor(shape=(3, 4), device="cpu")
2007+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
20052008
assert spec.to(device).device == device
2009+
assert spec.to(device).example_data == "example_data"
20062010

20072011
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
20082012
def test_onehot(self, shape1, device):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
22622266
assert r.shape == c.shape
22632267

22642268
def test_stack_non_tensor(self, shape, stack_dim):
2265-
spec0 = NonTensor(shape=shape, device="cpu")
2266-
spec1 = NonTensor(shape=shape, device="cpu")
2269+
spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data")
2270+
spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data")
22672271
new_spec = torch.stack([spec0, spec1], stack_dim)
22682272
shape_insert = list(shape)
22692273
shape_insert.insert(stack_dim, 2)
22702274
assert new_spec.shape == torch.Size(shape_insert)
22712275
assert new_spec.device == torch.device("cpu")
2276+
assert new_spec.example_data == "example_data"
22722277

22732278
def test_stack_onehot(self, shape, stack_dim):
22742279
n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):
36423647

36433648
class TestNonTensorSpec:
36443649
def test_sample(self):
3645-
nts = NonTensor(shape=(3, 4))
3650+
nts = NonTensor(shape=(3, 4), example_data="example_data")
36463651
assert nts.one((2,)).shape == (2, 3, 4)
36473652
assert nts.rand((2,)).shape == (2, 3, 4)
36483653
assert nts.zero((2,)).shape == (2, 3, 4)
3654+
assert nts.one((2,)).data == "example_data"
3655+
assert nts.rand((2,)).data == "example_data"
3656+
assert nts.zero((2,)).data == "example_data"
3657+
3658+
def test_example_data_ineq(self):
3659+
nts0 = NonTensor(shape=(3, 4), example_data="example_data")
3660+
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
3661+
assert nts0 != nts1
36493662

36503663

36513664
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")

0 commit comments

Comments
 (0)