@@ -5,10 +5,10 @@
 from __future__ import annotations

 import argparse
+import asyncio
 import contextlib
 import functools
 import gc
-
 import importlib
 import os
 import subprocess
@@ -52,7 +52,7 @@
     MultiSyncDataCollector,
 )

-from torchrl.collectors.llm_collector import LLMCollector
+from torchrl.collectors.llm import LLMCollector
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
@@ -3391,11 +3391,11 @@ def test_collector_rb_sync(self):
         assert assert_allclose_td(rbdata0, rbdata1)

     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("extend_buffer", [False, True])
     @pytest.mark.parametrize("env_creator", [False, True])
     @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
     def test_collector_rb_multisync(
-        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+        self, extend_buffer, env_creator, storagetype, tmpdir
     ):
         if not env_creator:
             env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
@@ -3420,7 +3420,7 @@ def test_collector_rb_multisync(
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=32,
-            replay_buffer_chunk=replay_buffer_chunk,
+            extend_buffer=extend_buffer,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -3430,7 +3430,7 @@ def test_collector_rb_multisync(
             assert len(rb) == pred_len
         collector.shutdown()
         assert len(rb) == 256
-        if not replay_buffer_chunk:
+        if not extend_buffer:
             steps_counts = rb["step_count"].squeeze().split(16)
             collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
             for step_count, ids in zip(steps_counts, collector_ids):
@@ -3442,11 +3442,11 @@ def test_collector_rb_multisync(
                 assert (idsdiff >= 0).all()

     @pytest.mark.skipif(not _has_gym, reason="requires gym.")
-    @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+    @pytest.mark.parametrize("extend_buffer", [False, True])
     @pytest.mark.parametrize("env_creator", [False, True])
     @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
     def test_collector_rb_multiasync(
-        self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+        self, extend_buffer, env_creator, storagetype, tmpdir
     ):
         if not env_creator:
             env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
@@ -3471,7 +3471,7 @@ def test_collector_rb_multiasync(
             replay_buffer=rb,
             total_frames=256,
             frames_per_batch=16,
-            replay_buffer_chunk=replay_buffer_chunk,
+            extend_buffer=extend_buffer,
         )
         torch.manual_seed(0)
         pred_len = 0
@@ -3481,7 +3481,7 @@ def test_collector_rb_multiasync(
             assert len(rb) >= pred_len
         collector.shutdown()
         assert len(rb) == 256
-        if not replay_buffer_chunk:
+        if not extend_buffer:
             steps_counts = rb["step_count"].squeeze().split(16)
             collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
             for step_count, ids in zip(steps_counts, collector_ids):
@@ -3575,6 +3575,18 @@ def vllm_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return llm_model

+    @pytest.fixture(scope="module")
+    def vllm_instance_opt(self):
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("facebook/opt-125m")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
     @pytest.fixture(scope="module")
     def transformers_instance(self):
         from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
@@ -3618,12 +3630,11 @@ def test_llm_collector_with_transformers(
         self._run_collector_test(total_steps, rb, policy, tokenizer)

     def _run_collector_test(self, total_steps, rb, policy, tokenizer):
-        bsz = 1
+        bsz = 4
         dataloader = DummyStrDataLoader(bsz)

         env = LLMEnv.from_dataloader(
             dataloader=dataloader,
-            tokenizer=tokenizer,
             str2str=True,
             batch_size=bsz,
             group_repeats=True,
@@ -3650,15 +3661,142 @@ def _run_collector_test(self, total_steps, rb, policy, tokenizer):

         if rb is not None:
             # Now check the buffer
-            assert len(rb) == total_steps
-            sample = rb.sample(1)
+            assert len(rb) >= total_steps
+            sample = rb.sample(4)
+            assert sample.shape == (4,)
+            assert not sample._has_exclusive_keys
             # Should match length
-            assert len(sample["text"]) == 1
+            assert len(sample["text"]) == 4
+            # assert len(sample["text"][0]) == 10, sample["text"][0]
             # Should be non-empty
             assert sample["text_response"] is not None
+            for i in range(4):
+                # Check that there are more chars in the next step
+                assert len(sample["text"][i]) < len(sample["next", "text"][i])
         else:
             stack = torch.cat(stack)
-            assert stack.numel() == total_steps
+            assert not stack._has_exclusive_keys
+            assert stack.numel() == max(-(total_steps // -4) * 4, 4)
+            stack = stack.view(-1)
+            for i in range(stack.numel()):
+                # Check that there are more chars in the next step
+                assert len(stack["text"][i]) < len(stack["next", "text"][i])
+        assert collector._frames >= total_steps
+
+    def test_llm_collector_start(self, vllm_instance):
+        asyncio.run(self._async_run_collector_test(vllm_instance))
+
+    async def _async_run_collector_test(self, vllm_instance):
+        total_steps = 20
+        policy = vLLMWrapper(vllm_instance)
+        vllm_instance.get_tokenizer()
+        bsz = 4
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+        )
+
+        rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+        )
+        collector.start()
+
+        i = 0
+        wait = 0
+        while True:
+            while not len(rb):
+                await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
+                wait += 1
+                if wait > 20:
+                    raise RuntimeError
+            sample = rb.sample(10)
+            for j in range(sample.numel()):
+                # Check that there are more chars in the next step
+                assert len(sample["text"][j]) < len(sample["next", "text"][j])
+            assert not sample._has_exclusive_keys, sample
+            await asyncio.sleep(0.1)  # Use asyncio.sleep instead of time.sleep
+            i += 1
+            if i == 5:
+                break
+        assert collector._frames >= total_steps
+
+        await collector.async_shutdown()
+
+    @pytest.mark.slow
+    @pytest.mark.parametrize("rb", [False, True])
+    @pytest.mark.parametrize("yield_only_last_steps", [False, True])
+    def test_llm_collector_completed(
+        self, vllm_instance_opt, rb, yield_only_last_steps
+    ):
+        policy = vLLMWrapper(vllm_instance_opt)
+        tokenizer = vllm_instance_opt.get_tokenizer()
+        bsz = 4
+        total_steps = 20
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+        # To make sure the env breaks at some point
+        env = env.append_transform(StepCounter(max_steps=100))
+
+        if rb:
+            rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        else:
+            rb = None
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+            yield_completed_trajectories=True,
+            yield_only_last_steps=yield_only_last_steps,
+        )
+        assert collector.yield_completed_trajectories
+        assert collector.yield_only_last_steps is yield_only_last_steps
+
+        cur_total_steps = 0
+        has_found_one_with_more_steps = False
+        for data in collector:
+            if rb is None:
+                assert data.ndim == 1
+                assert (data["next", "step_count"] < 99).all()
+                cur_total_steps += data.numel()
+                for i in range(data.numel()):
+                    # Check that there are more chars in the next step
+                    assert len(data["text"][i]) < len(data["next", "text"][i])
+                if yield_only_last_steps:
+                    assert data.shape == (1,)
+                else:
+                    has_found_one_with_more_steps |= data.numel() > 1
+            else:
+                assert data is None
+                sample = rb.sample(5)
+                for i in range(sample.numel()):
+                    # Check that there are more chars in the next step
+                    assert len(sample["text"][i]) < len(sample["next", "text"][i])
+                assert sample.ndim == 1
+                assert sample.shape == (5,)
+                assert (sample["next", "step_count"] < 99).all()
+                cur_total_steps += 1
+            assert collector._frames >= cur_total_steps
+        if rb is None and not yield_only_last_steps:
+            assert has_found_one_with_more_steps
+        assert collector._frames >= total_steps


 if __name__ == "__main__":