@@ -4861,31 +4861,22 @@ def policy(td):
4861
4861
r_reset = r [..., ::max_steps ]
4862
4862
if not batched :
4863
4863
if str2str :
4864
+ all_strings = r_reset .view (- 1 )[LLMEnv ._DEFAULT_STR_KEY ]
4865
+ assert sum (s == all_strings [0 ] for s in all_strings ) == repeats
4866
+ assert sum (s == all_strings [repeats ] for s in all_strings ) == repeats
4864
4867
assert (
4865
- r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4866
- == r_reset [..., 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4868
+ sum (s == all_strings [repeats * 2 ] for s in all_strings ) == repeats
4867
4869
)
4870
+ else :
4871
+ all_tokens = r_reset .view (- 1 )[LLMEnv ._DEFAULT_TOKEN_KEY ]
4872
+ assert sum ((s == all_tokens [0 ]).all () for s in all_tokens ) == repeats
4868
4873
assert (
4869
- r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4870
- == r_reset [..., 2 ][LLMEnv ._DEFAULT_STR_KEY ]
4874
+ sum ((s == all_tokens [repeats ]).all () for s in all_tokens ) == repeats
4871
4875
)
4872
4876
assert (
4873
- r_reset [..., 0 ][ LLMEnv . _DEFAULT_STR_KEY ]
4874
- != r_reset [..., 3 ][ LLMEnv . _DEFAULT_STR_KEY ]
4877
+ sum (( s == all_tokens [ repeats * 2 ]). all () for s in all_tokens )
4878
+ == repeats
4875
4879
)
4876
- else :
4877
- assert (
4878
- r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4879
- == r_reset [..., 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4880
- ).all ()
4881
- assert (
4882
- r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4883
- == r_reset [..., 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4884
- ).all ()
4885
- assert (
4886
- r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4887
- != r_reset [..., 3 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4888
- ).any ()
4889
4880
else :
4890
4881
# When batched, each block contains the 3 reset packs
4891
4882
if str2str :
0 commit comments