Open
Description
Lines 72 to 78 in 85899d7
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None: continue
                self.state[p]['shared_steps'] += 1
                self.state[p]['step'] = self.state[p]['shared_steps'][0] - 1 # a "step += 1" comes later
        super.step(closure)
Why did you override the default implementation of step(closure)? The default one computes the exponential moving averages. Your implementation doesn't seem to calculate the step count, because it always returns None. I looked over torch's documentation for step() but couldn't understand exactly why you chose to override the step function.
Kindly review the following PR: #9
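To make the question concrete, here is a minimal sketch of the override pattern I would have expected. It is not necessarily what the PR does. To keep it runnable without torch, `BaseOptimizer` and `Param` are hypothetical stand-ins (not torch classes), and the one-element list stands in for the shared-memory counter. The two points it illustrates: the parent has to be called as `super().step(closure)`, since a bare `super.step` raises AttributeError, and the override has to return the parent's result, otherwise the caller always gets None.

```python
class Param:
    """Hypothetical stand-in for a parameter; only carries a .grad attribute."""
    def __init__(self, grad=None):
        self.grad = grad


class BaseOptimizer:
    """Hypothetical stand-in for the parent optimizer (NOT torch.optim.Adam).

    Like torch optimizers, step() evaluates the optional closure, returns its
    loss, and bumps the per-parameter step count for parameters with a grad.
    """
    def __init__(self, params):
        params = list(params)
        self.param_groups = [{"params": params}]
        self.state = {p: {"step": 0} for p in params}

    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["step"] += 1  # the 'step += 1' that "comes later"
        return loss


class SharedOptimizer(BaseOptimizer):
    """Sketch of the override: sync the step count, then defer to the parent."""
    def __init__(self, params):
        params = list(params)
        super().__init__(params)
        for p in params:
            # Assumption: a one-element list plays the role of the shared counter.
            self.state[p]["shared_steps"] = [0]

    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["shared_steps"][0] += 1
                # Set step one behind the shared counter; the parent adds 1 back.
                self.state[p]["step"] = self.state[p]["shared_steps"][0] - 1
        # Call the parent via super() and return its result.
        return super().step(closure)


params = [Param(grad=1.0), Param()]
opt = SharedOptimizer(params)
loss = opt.step(lambda: 0.5)
```

With this shape, `loss` holds the closure's value instead of None, and the local step count of each parameter with a gradient ends up equal to the shared counter after every call.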