Commit 31df775

Author: Vincent Moens (committed)
[Feature] More options for LLM collectors
ghstack-source-id: 74e8155
Pull Request resolved: #2891
1 parent 4135a83 commit 31df775

File tree

7 files changed: +431 −88 lines changed


test/test_collector.py

Lines changed: 154 additions & 16 deletions
@@ -5,10 +5,10 @@
 from __future__ import annotations
 
 import argparse
+import asyncio
 import contextlib
 import functools
 import gc
-
 import importlib
 import os
 import subprocess
@@ -52,7 +52,7 @@
     MultiSyncDataCollector,
 )
 
-from torchrl.collectors.llm_collector import LLMCollector
+from torchrl.collectors.llm import LLMCollector
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
@@ -3391,11 +3391,11 @@ def test_collector_rb_sync(self):
         assert assert_allclose_td(rbdata0, rbdata1)
 
     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("extend_buffer", [False, True])
     @pytest.mark.parametrize("env_creator", [False, True])
     @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
     def test_collector_rb_multisync(
-        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+        self, extend_buffer, env_creator, storagetype, tmpdir
     ):
         if not env_creator:
             env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
@@ -3420,7 +3420,7 @@ def test_collector_rb_multisync(
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=32,
-            replay_buffer_chunk=replay_buffer_chunk,
+            extend_buffer=extend_buffer,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -3430,7 +3430,7 @@ def test_collector_rb_multisync(
             assert len(rb) == pred_len
         collector.shutdown()
         assert len(rb) == 256
-        if not replay_buffer_chunk:
+        if not extend_buffer:
             steps_counts = rb["step_count"].squeeze().split(16)
             collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
             for step_count, ids in zip(steps_counts, collector_ids):
@@ -3442,11 +3442,11 @@ def test_collector_rb_multisync(
                 assert (idsdiff >= 0).all()
 
     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("extend_buffer", [False, True])
     @pytest.mark.parametrize("env_creator", [False, True])
     @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
     def test_collector_rb_multiasync(
-        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+        self, extend_buffer, env_creator, storagetype, tmpdir
     ):
         if not env_creator:
             env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
@@ -3471,7 +3471,7 @@ def test_collector_rb_multiasync(
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=16,
-            replay_buffer_chunk=replay_buffer_chunk,
+            extend_buffer=extend_buffer,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -3481,7 +3481,7 @@ def test_collector_rb_multiasync(
             assert len(rb) >= pred_len
         collector.shutdown()
         assert len(rb) == 256
-        if not replay_buffer_chunk:
+        if not extend_buffer:
             steps_counts = rb["step_count"].squeeze().split(16)
             collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
             for step_count, ids in zip(steps_counts, collector_ids):
@@ -3575,6 +3575,18 @@ def vllm_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return llm_model
 
+    @pytest.fixture(scope="module")
+    def vllm_instance_opt(self):
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("facebook/opt-125m")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
     @pytest.fixture(scope="module")
     def transformers_instance(self):
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
@@ -3618,12 +3630,11 @@ def test_llm_collector_with_transformers(
         self._run_collector_test(total_steps, rb, policy, tokenizer)
 
     def _run_collector_test(self, total_steps, rb, policy, tokenizer):
-        bsz = 1
+        bsz = 4
         dataloader = DummyStrDataLoader(bsz)
 
         env = LLMEnv.from_dataloader(
             dataloader=dataloader,
-            tokenizer=tokenizer,
             str2str=True,
             batch_size=bsz,
             group_repeats=True,
@@ -3650,15 +3661,142 @@ def _run_collector_test(self, total_steps, rb, policy, tokenizer):
 
         if rb is not None:
             # Now check the buffer
-            assert len(rb) == total_steps
-            sample = rb.sample(1)
+            assert len(rb) >= total_steps
+            sample = rb.sample(4)
+            assert sample.shape == (4,)
+            assert not sample._has_exclusive_keys
             # Should match length
-            assert len(sample["text"]) == 1
+            assert len(sample["text"]) == 4
+            # assert len(sample["text"][0]) == 10, sample["text"][0]
             # Should be non-empty
             assert sample["text_response"] is not None
+            for i in range(4):
+                # Check that there are more chars in the next step
+                assert len(sample["text"][i]) < len(sample["next", "text"][i])
         else:
             stack = torch.cat(stack)
-            assert stack.numel() == total_steps
+            assert not stack._has_exclusive_keys
+            assert stack.numel() == max(-(total_steps // -4) * 4, 4)
+            stack = stack.view(-1)
+            for i in range(stack.numel()):
+                # Check that there are more chars in the next step
+                assert len(stack["text"][i]) < len(stack["next", "text"][i])
+        assert collector._frames >= total_steps
+
+    def test_llm_collector_start(self, vllm_instance):
+        asyncio.run(self._async_run_collector_test(vllm_instance))
+
+    async def _async_run_collector_test(self, vllm_instance):
+        total_steps = 20
+        policy = vLLMWrapper(vllm_instance)
+        vllm_instance.get_tokenizer()
+        bsz = 4
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+        )
+
+        rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+        )
+        collector.start()
+
+        i = 0
+        wait = 0
+        while True:
+            while not len(rb):
+                await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
+                wait += 1
+                if wait > 20:
+                    raise RuntimeError
+            sample = rb.sample(10)
+            for i in range(sample.numel()):
+                # Check that there are more chars in the next step
+                assert len(sample["text"][i]) < len(sample["next", "text"][i])
+            assert not sample._has_exclusive_keys, sample
+            await asyncio.sleep(0.1)  # Use asyncio.sleep instead of time.sleep
+            i += 1
+            if i == 5:
+                break
+        assert collector._frames >= total_steps
+
+        await collector.async_shutdown()
+
+    @pytest.mark.slow
+    @pytest.mark.parametrize("rb", [False, True])
+    @pytest.mark.parametrize("yield_only_last_steps", [False, True])
+    def test_llm_collector_completed(
+        self, vllm_instance_opt, rb, yield_only_last_steps
+    ):
+        policy = vLLMWrapper(vllm_instance_opt)
+        tokenizer = vllm_instance_opt.get_tokenizer()
+        bsz = 4
+        total_steps = 20
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+        # To make sure the env breaks at some point
+        env = env.append_transform(StepCounter(max_steps=100))
+
+        if rb:
+            rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        else:
+            rb = None
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+            yield_completed_trajectories=True,
+            yield_only_last_steps=yield_only_last_steps,
+        )
+        assert collector.yield_completed_trajectories
+        assert collector.yield_only_last_steps is yield_only_last_steps
+
+        cur_total_steps = 0
+        has_found_one_with_more_steps = False
+        for data in collector:
+            if rb is None:
+                assert data.ndim == 1
+                assert (data["next", "step_count"] < 99).all()
+                cur_total_steps += data.numel()
+                for i in range(data.numel()):
+                    # Check that there are more chars in the next step
+                    assert len(data["text"][i]) < len(data["next", "text"][i])
+                if yield_only_last_steps:
+                    assert data.shape == (1,)
+                else:
+                    has_found_one_with_more_steps |= data.numel() > 1
+            else:
+                assert data is None
+                sample = rb.sample(5)
+                for i in range(sample.numel()):
+                    # Check that there are more chars in the next step
+                    assert len(sample["text"][i]) < len(sample["next", "text"][i])
+                assert sample.ndim == 1
+                assert sample.shape == (5,)
+                assert (sample["next", "step_count"] < 99).all()
+                cur_total_steps += 1
+            assert collector._frames >= cur_total_steps
+        if rb is None and not yield_only_last_steps:
+            assert has_found_one_with_more_steps
+        assert collector._frames >= total_steps
 
 
 if __name__ == "__main__":
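
The new tests double as documentation for the added collector options. Below is a minimal usage sketch, not part of the commit, that mirrors test_llm_collector_completed: an LLMEnv built from a dataloader feeds an LLMCollector that yields only completed trajectories and writes into a replay buffer. Only the torchrl.collectors.llm.LLMCollector import path is confirmed by this diff; the import locations of LLMEnv and vLLMWrapper, and the DummyStrDataLoader helper (taken from the test suite above), are assumptions and may differ by torchrl version.

    import vllm

    from torchrl.collectors.llm import LLMCollector   # import path introduced by this commit
    from torchrl.data import LazyStackStorage, ReplayBuffer
    from torchrl.envs import LLMEnv                   # assumed location
    from torchrl.modules import vLLMWrapper           # assumed location

    model = vllm.LLM("facebook/opt-125m")
    policy = vLLMWrapper(model)                       # wrap the vLLM engine as a policy
    tokenizer = model.get_tokenizer()

    bsz, total_steps = 4, 20
    env = LLMEnv.from_dataloader(
        dataloader=DummyStrDataLoader(bsz),           # test helper; any prompt dataloader works
        str2str=True,
        batch_size=bsz,
        group_repeats=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
    collector = LLMCollector(
        env=env,
        policy_factory=lambda: policy,
        steps_per_batch=env.batch_size[0],
        replay_buffer=rb,
        total_steps=total_steps,
        yield_completed_trajectories=True,            # only emit finished trajectories
        yield_only_last_steps=False,                  # keep every step of each trajectory
    )

    for data in collector:
        # With a replay buffer attached the iterator yields None and the data
        # lands in rb, exactly as the test above asserts.
        assert data is None
        sample = rb.sample(5)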
