Skip to content

Commit 7c43aef

Browse files
authored
[BugFix]: Safe state normalization when std=0 (#323)
1 parent 53bee4b commit 7c43aef

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchrl/trainers/helpers/envs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,9 @@ def get_stats_random_rollout(
371371
s = td_stats.get(key).std().clamp_min(1e-5)
372372
else:
373373
m = td_stats.get(key).mean(dim=0)
374-
s = td_stats.get(key).std(dim=0).clamp_min(1e-5)
374+
s = td_stats.get(key).std(dim=0)
375+
m[s == 0] = 0.0
376+
s[s == 0] = 1.0
375377

376378
print(
377379
f"stats computed for {td_stats.numel()} steps. Got: \n"

0 commit comments

Comments
 (0)