Skip to content

Commit db1a7d4

Browse files
authored
[Lint] Add TorchFix linter (#1580)
1 parent a02679b commit db1a7d4

File tree

4 files changed

+15
-6
lines changed

4 files changed

+15
-6
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ repos:
2727
additional_dependencies:
2828
- flake8-bugbear==22.10.27
2929
- flake8-comprehensions==3.10.1
30-
30+
- torchfix==0.0.2
3131

3232
- repo: https://github.com/PyCQA/pydocstyle
3333
rev: 6.1.1

setup.cfg

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,18 @@ per-file-ignores =
1515
test/smoke_test_deps.py: F401
1616
test_*.py: F841, E731, E266
1717
test/opengl_rendering.py: F401
18+
test/test_modules.py: F841, E731, E266, TOR101
19+
test/test_tensordictmodules.py: F841, E731, E266, TOR101
20+
torchrl/objectives/cql.py: TOR101
21+
torchrl/objectives/deprecated.py: TOR101
22+
torchrl/objectives/iql.py: TOR101
23+
torchrl/objectives/redq.py: TOR101
24+
torchrl/objectives/sac.py: TOR101
25+
torchrl/objectives/td3.py: TOR101
26+
torchrl/objectives/value/advantages.py: TOR101
1827

1928
exclude = venv
20-
extend-select = B901, C401, C408, C409
29+
extend-select = B901, C401, C408, C409, TOR0, TOR1, TOR2
2130

2231
[pydocstyle]
2332
;select = D417 # Missing argument descriptions in the docstring

test/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6723,8 +6723,8 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa
67236723
with pytest.raises(AssertionError):
67246724
torch.testing.assert_close(cur_embedding[:, 1:], last_embedding[:, :-1])
67256725

6726-
explicit_reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - (
6727-
-torch.norm(last_embedding - goal_embedding, dim=-1)
6726+
explicit_reward = -torch.linalg.norm(cur_embedding - goal_embedding, dim=-1) - (
6727+
-torch.linalg.norm(last_embedding - goal_embedding, dim=-1)
67286728
)
67296729
torch.testing.assert_close(explicit_reward, td["next", "reward"].squeeze())
67306730

torchrl/envs/transforms/vip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ def _step(
380380
cur_embedding = next_tensordict.get(self.out_keys[0])
381381
if last_embedding is not None:
382382
goal_embedding = tensordict["goal_embedding"]
383-
reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - (
384-
-torch.norm(last_embedding - goal_embedding, dim=-1)
383+
reward = -torch.linalg.norm(cur_embedding - goal_embedding, dim=-1) - (
384+
-torch.linalg.norm(last_embedding - goal_embedding, dim=-1)
385385
)
386386
next_tensordict.set("reward", reward)
387387
return next_tensordict

0 commit comments

Comments
 (0)