
Commit daa67cb

[Refactor] Simplify LLMEnv

Author: Vincent Moens
ghstack-source-id: 9367bb1
Pull Request resolved: #2897
1 parent: 78cd755
14 files changed: +482 −545 lines


test/mocking_classes.py

Lines changed: 13 additions & 6 deletions
@@ -2478,6 +2478,8 @@ def _step(
 
 class DummyStrDataLoader:
     def __init__(self, batch_size=0):
+        if isinstance(batch_size, tuple):
+            batch_size = torch.Size(batch_size).numel()
         self.batch_size = batch_size
 
     def generate_random_string(self, length=10):
@@ -2489,21 +2491,25 @@ def __iter__(self):
 
     def __next__(self):
         if self.batch_size == 0:
-            return self.generate_random_string()
+            return {"text": self.generate_random_string()}
         else:
-            return [self.generate_random_string() for _ in range(self.batch_size)]
+            return {
+                "text": [self.generate_random_string() for _ in range(self.batch_size)]
+            }
 
 
 class DummyTensorDataLoader:
     def __init__(self, batch_size=0, max_length=10, padding=False):
+        if isinstance(batch_size, tuple):
+            batch_size = torch.Size(batch_size).numel()
         self.batch_size = batch_size
         self.max_length = max_length
         self.padding = padding
 
     def generate_random_tensor(self):
         """Generate a tensor of random int64 values."""
         length = random.randint(1, self.max_length)
-        rt = torch.randint(0, 100, (length,))
+        rt = torch.randint(1, 10000, (length,))
         return rt
 
     def pad_tensor(self, tensor):
@@ -2517,11 +2523,12 @@ def __iter__(self):
     def __next__(self):
         if self.batch_size == 0:
             tensor = self.generate_random_tensor()
-            return self.pad_tensor(tensor) if self.padding else tensor
+            tokens = self.pad_tensor(tensor) if self.padding else tensor
         else:
             tensors = [self.generate_random_tensor() for _ in range(self.batch_size)]
             if self.padding:
                 tensors = [self.pad_tensor(tensor) for tensor in tensors]
-                return torch.stack(tensors)
+                tokens = torch.stack(tensors)
             else:
-                return tensors
+                tokens = tensors
+        return {"tokens": tokens, "attention_mask": tokens != 0}

test/test_actors.py

Lines changed: 39 additions & 17 deletions
@@ -1361,37 +1361,55 @@ def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
         else:
             assert isinstance(tokens, list)
 
-    def test_vllm_collection(self, vllm_instance):
+    @pytest.mark.parametrize("from_text", [True])
+    def test_vllm_collection(self, vllm_instance, from_text):
         policy = vLLMWrapper(
             vllm_instance,
             return_log_probs=True,
-            generate_kwargs={"max_tokens": 10},
+            generate_kwargs={"max_tokens": 32},
+            from_text=from_text in (True, None),
         )
-        self._run_check_collector(policy)
+        tokenizer = vllm_instance.get_tokenizer()
+        self._run_check_collector(policy, from_text=from_text, tokenizer=tokenizer)
 
     def test_transformers_collection(self):
         ...
 
     @classmethod
-    def env_constructor(cls):
-        dl = DummyStrDataLoader(batch_size=32)
-        env = LLMEnv.from_dataloader(
-            dl,
-            batch_size=16,
-            repeats=4,
-            # str2str=True, group_repeats=True
-        )
-        assert env.batch_size == (64,)
-        return env
+    def env_constructor(cls, **kwargs):
+        def make():
+            # if kwargs.get("from_text", True):
+            dl = DummyStrDataLoader(batch_size=32)
+            # else:
+            #     dl = DummyTensorDataLoader(batch_size=32)
+            env = LLMEnv.from_dataloader(
+                dl,
+                batch_size=4,
+                repeats=4,
+                **kwargs,
+            )
+            assert env.batch_size == (16,)
+            return env
+
+        return make
 
-    def _run_check_collector(self, policy):
+    def _run_check_collector(self, policy, from_text, tokenizer):
+        if from_text is None:
+            kwargs = {"eos_token_id": tokenizer.eos_token_id}
+        else:
+            kwargs = {
+                "from_text": from_text,
+                "tokenizer": tokenizer,
+                "eos_token_id": tokenizer.eos_token_id,
+            }
         collector = SyncDataCollector(
-            self.env_constructor,
+            self.env_constructor(**kwargs),
             policy=policy,
-            frames_per_batch=128,
-            total_frames=512,
+            frames_per_batch=32,
+            total_frames=128,
             use_buffers=False,
         )
+        t = 0
         for data in collector:
             assert isinstance(data, LazyStackedTensorDict)
             assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)
@@ -1403,6 +1421,10 @@ def _run_check_collector(self, policy):
             assert ("next", "text") in data
             # tokens
             assert "tokens" in data
+
+            t += data.numel()
+            assert collector._frames == t
+            assert t < 512, t
             # assert ("next", "tokens") in data
 
     def test_vllm_generate_multiple_trajs(self, vllm_instance):
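Two patterns carry this hunk: env_constructor now returns a zero-argument closure, so per-test kwargs (from_text, tokenizer, eos_token_id) reach LLMEnv.from_dataloader while SyncDataCollector still receives a plain callable, and the collection loop accumulates data.numel() to check it against collector._frames. A minimal sketch of the closure pattern, independent of torchrl; make_factory and the dict stand-in for the env are hypothetical names, not library API.

from typing import Callable


def make_factory(**kwargs) -> Callable[[], dict]:
    # Bind construction kwargs now; defer the (possibly expensive) build to call time,
    # e.g. so a collector can construct the env inside a worker process.
    def make() -> dict:
        # Stand-in for LLMEnv.from_dataloader(dl, batch_size=4, repeats=4, **kwargs):
        # batch_size * repeats prompts are grouped into one batched env.
        return {"batch_size": (kwargs["batch_size"] * kwargs["repeats"],)}

    return make


factory = make_factory(batch_size=4, repeats=4)
env = factory()
assert env["batch_size"] == (16,)  # mirrors the assert inside env_constructor above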

test/test_collector.py

Lines changed: 5 additions & 4 deletions
@@ -3636,9 +3636,10 @@ def _run_collector_test(self, total_steps, rb, policy, tokenizer):
 
         env = LLMEnv.from_dataloader(
             dataloader=dataloader,
-            str2str=True,
+            from_text=True,
             batch_size=bsz,
             group_repeats=True,
+            eos_token_id=tokenizer.eos_token_id,
         )
         if rb:
             rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
@@ -3695,7 +3696,7 @@ async def test_llm_collector_start(self, vllm_instance):
 
         env = LLMEnv.from_dataloader(
             dataloader=dataloader,
-            str2str=True,
+            from_text=True,
             batch_size=bsz,
             group_repeats=True,
         )
@@ -3748,7 +3749,7 @@ def test_llm_collector_completed(
 
         env = LLMEnv.from_dataloader(
             dataloader=dataloader,
-            str2str=True,
+            from_text=True,
             batch_size=bsz,
             group_repeats=True,
             eos_token_id=tokenizer.eos_token_id,
@@ -3854,7 +3855,7 @@ def test_llm_collector_completed_async(
         def env_maker():
             env = LLMEnv.from_dataloader(
                 dataloader=dataloader,
-                str2str=True,
+                from_text=True,
                 batch_size=(),
                 group_repeats=True,
                 eos_token_id=tokenizer.eos_token_id,
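All four call sites make the same change: the str2str flag becomes from_text, and where completion detection matters the tokenizer's EOS id is passed in. A sketch of a construction under the new signature; the torchrl.envs export path, the dict-yielding DummyStrDataLoader from mocking_classes.py, and the literal EOS id are assumptions.

from torchrl.envs import LLMEnv  # assumed export path

from mocking_classes import DummyStrDataLoader  # the dict-yielding loader patched above

dataloader = DummyStrDataLoader(batch_size=32)
env = LLMEnv.from_dataloader(
    dataloader=dataloader,
    from_text=True,   # replaces the old str2str=True
    batch_size=4,
    group_repeats=True,
    eos_token_id=0,   # normally tokenizer.eos_token_id; a literal stands in here
)
env.reset()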

test/test_cost.py

Lines changed: 2 additions & 1 deletion
@@ -16708,7 +16708,8 @@ def test_hf(self, from_text):
             dl,
             tokenizer=tokenizer if not from_text else None,
             batch_size=(32,),
-            str2str=True,
+            from_text=True,
+            eos_token_id=tokenizer.eos_token_id,
         )
 
         class RewardTransform(Transform):
