
Commit b66fcd4

Author: Vincent Moens
[BugFix] Fix .item() warning on tensors that require grad
ghstack-source-id: 502bdda
Pull Request resolved: #2885
1 parent c9caf3d commit b66fcd4
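
All five hunks in test/test_cost.py apply the same one-line pattern named in the commit title: wrap the distance computation in torch.no_grad() before calling .item(), so the subtraction and norm are never recorded by autograd and the scalar read never touches a graph-tracked tensor. A minimal standalone before/after sketch (the tensor names are illustrative, not TorchRL's):

import torch

# Standalone sketch of the pattern this commit applies (not TorchRL code).
target_val = torch.randn(4, requires_grad=True)
source_val = torch.randn(4, requires_grad=True)

# Before: the subtraction and norm are recorded by autograd, and .item()
# then reads a scalar off a graph-tracked result.
d0 = (target_val - source_val).norm().item()

# After: the same arithmetic runs untracked, so .item() reads a plain scalar
# and no grad-related warning can fire.
with torch.no_grad():
    d0 = (target_val - source_val).norm().item()

Exactly which PyTorch version or warning filter triggered the original message is not visible from the diff itself.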

2 files changed: +14 −9


test/test_cost.py

Lines changed: 13 additions & 8 deletions
@@ -13686,7 +13686,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             assert target_val.device == source_val.device, key
             if target_val.dtype == torch.long:
                 continue
-            d0 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d0 += (target_val - source_val).norm().item()

         assert d0 > 0
         if mode == "hard":
@@ -13700,7 +13701,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
                 target_val = upd._targets[key]
                 if target_val.dtype == torch.long:
                     continue
-                d1 += (target_val - source_val).norm().item()
+                with torch.no_grad():
+                    d1 += (target_val - source_val).norm().item()

             assert d1 == d0, i
             assert upd.counter == i
@@ -13715,7 +13717,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             target_val = upd._targets[key]
             if target_val.dtype == torch.long:
                 continue
-            d1 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d1 += (target_val - source_val).norm().item()
         assert d1 < d0

     elif mode == "soft":
@@ -13728,7 +13731,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             target_val = upd._targets[key]
             if target_val.dtype == torch.long:
                 continue
-            d1 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d1 += (target_val - source_val).norm().item()
         assert d1 < d0
     with pytest.warns(UserWarning, match="already"):
         upd.init_()
@@ -13741,7 +13745,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             target_val = upd._targets[key]
            if target_val.dtype == torch.long:
                 continue
-            d2 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d2 += (target_val - source_val).norm().item()
         assert d2 < 1e-6

@@ -16668,17 +16673,17 @@ class TestPPO4LLMs:
     @pytest.mark.parametrize("from_text", [True, False])
     def test_hf(self, from_text):
         from torchrl.envs import LLMEnv, Transform
-        from torchrl.modules import from_hf_transformers
+        from torchrl.modules import TransformersWrapper
         from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM

         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
         tokenizer.pad_token = tokenizer.eos_token

         model = OPTForCausalLM(OPTConfig())
-        policy_inference = from_hf_transformers(
+        policy_inference = TransformersWrapper(
             model, tokenizer=tokenizer, generate=True, from_text=from_text
         )
-        policy_train = from_hf_transformers(
+        policy_train = TransformersWrapper(
             model, tokenizer=tokenizer, generate=False, from_text=False
         )
         for p in policy_train.parameters():
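
The last hunk in this file tracks an API rename: the from_hf_transformers factory becomes the TransformersWrapper class. Judging only from the calls visible in the diff, generate=True yields a generating policy for rollouts while generate=False presumably scores already-given inputs for training. A hedged usage sketch, with the model and tokenizer built exactly as in the test above (other constructor parameters may exist):

from torchrl.modules import TransformersWrapper
from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM

# Model and tokenizer as in the test above.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = tokenizer.eos_token
model = OPTForCausalLM(OPTConfig())

# Two wrapper modes, arguments taken verbatim from the diff:
policy_inference = TransformersWrapper(
    model, tokenizer=tokenizer, generate=True, from_text=True
)  # generation mode: produces new tokens, here consuming raw text
policy_train = TransformersWrapper(
    model, tokenizer=tokenizer, generate=False, from_text=False
)  # scoring mode: presumably returns log-probs for the training loss

Whether from_text fully controls the wrapper's input contract is not visible from this diff.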

torchrl/objectives/common.py

Lines changed: 1 addition & 1 deletion
@@ -385,7 +385,7 @@ def _compare_and_expand(param):
         p_out = param.expand(expand_dim, *param.shape).clone()
         p_out = nn.Parameter(
             p_out.uniform_(
-                p_out.min().item(), p_out.max().item()
+                p_out.data.min().item(), p_out.data.max().item()
             ).requires_grad_()
         )
         return p_out
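
The torchrl/objectives/common.py hunk is the same idea expressed through .data: p_out is built from a Parameter, so p_out.min()/p_out.max() would be recorded by autograd before .item() extracts the bounds passed to uniform_. Reading through .data gives a detached view, keeping both scalar reads off the graph. A hypothetical standalone repro of just the scalar read (not the TorchRL call site):

import torch
from torch import nn

p = nn.Parameter(torch.randn(3, 4))

lo = p.min().item()             # min() is recorded on the autograd graph first
lo_fixed = p.data.min().item()  # .data detaches, so nothing is recorded

assert lo == lo_fixed  # same value either way; only graph tracking differs

An equivalent alternative would be p.detach().min().item(); the commit leaves the in-place uniform_ initialisation itself unchanged.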
