@@ -13686,7 +13686,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             assert target_val.device == source_val.device, key
             if target_val.dtype == torch.long:
                 continue
-            d0 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d0 += (target_val - source_val).norm().item()
 
         assert d0 > 0
         if mode == "hard":
@@ -13700,7 +13701,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
                     target_val = upd._targets[key]
                     if target_val.dtype == torch.long:
                         continue
-                    d1 += (target_val - source_val).norm().item()
+                    with torch.no_grad():
+                        d1 += (target_val - source_val).norm().item()
 
                 assert d1 == d0, i
                 assert upd.counter == i
@@ -13715,7 +13717,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
                 target_val = upd._targets[key]
                 if target_val.dtype == torch.long:
                     continue
-                d1 += (target_val - source_val).norm().item()
+                with torch.no_grad():
+                    d1 += (target_val - source_val).norm().item()
             assert d1 < d0
 
         elif mode == "soft":
@@ -13728,7 +13731,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
                 target_val = upd._targets[key]
                 if target_val.dtype == torch.long:
                     continue
-                d1 += (target_val - source_val).norm().item()
+                with torch.no_grad():
+                    d1 += (target_val - source_val).norm().item()
             assert d1 < d0
         with pytest.warns(UserWarning, match="already"):
             upd.init_()
@@ -13741,7 +13745,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             target_val = upd._targets[key]
             if target_val.dtype == torch.long:
                 continue
-            d2 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d2 += (target_val - source_val).norm().item()
         assert d2 < 1e-6
 
 
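The hunks above all repeat the same pattern: the test accumulates the distance between source and target parameters, and the norm computation is now wrapped in `torch.no_grad()`. A minimal sketch of that pattern, with a hypothetical source/target module pair standing in for the loss module used by the test:

```python
import torch
from torch import nn

# Hypothetical stand-ins for the source network and its target copy.
source = nn.Linear(4, 4)
target = nn.Linear(4, 4)

d0 = 0.0
with torch.no_grad():
    # no_grad avoids building an autograd graph for every parameter pair;
    # only the scalar distance is needed, so any graph would be wasted work.
    for source_val, target_val in zip(source.parameters(), target.parameters()):
        if target_val.dtype == torch.long:
            continue  # skip integer buffers, as in the test above
        d0 += (target_val - source_val).norm().item()

assert d0 > 0  # independently initialized networks should differ
```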
@@ -16668,17 +16673,17 @@ class TestPPO4LLMs:
     @pytest.mark.parametrize("from_text", [True, False])
     def test_hf(self, from_text):
         from torchrl.envs import LLMEnv, Transform
-        from torchrl.modules import from_hf_transformers
+        from torchrl.modules import TransformersWrapper
         from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
 
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
         tokenizer.pad_token = tokenizer.eos_token
 
         model = OPTForCausalLM(OPTConfig())
-        policy_inference = from_hf_transformers(
+        policy_inference = TransformersWrapper(
             model, tokenizer=tokenizer, generate=True, from_text=from_text
         )
-        policy_train = from_hf_transformers(
+        policy_train = TransformersWrapper(
             model, tokenizer=tokenizer, generate=False, from_text=False
         )
         for p in policy_train.parameters():
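For reference, a condensed sketch of the setup this test moves to: `TransformersWrapper` replaces the old `from_hf_transformers` helper, and the same underlying model is wrapped twice, once for generation and once for log-prob scoring. The constructor arguments mirror the calls in the hunk above; the parameter-sharing assert at the end is an illustrative assumption, not part of the test:

```python
from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
from torchrl.modules import TransformersWrapper

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = tokenizer.eos_token

# A randomly initialized OPT model keeps the example self-contained.
model = OPTForCausalLM(OPTConfig())

# generate=True: the wrapper runs autoregressive generation (inference policy).
policy_inference = TransformersWrapper(
    model, tokenizer=tokenizer, generate=True, from_text=True
)
# generate=False: the wrapper only scores given tokens (training-time policy).
policy_train = TransformersWrapper(
    model, tokenizer=tokenizer, generate=False, from_text=False
)

# Both wrappers hold the very same parameter tensors, since they wrap one model.
assert set(policy_inference.parameters()) == set(policy_train.parameters())
```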