Bnb/disc training fix #259

Conversation
… training first batch
… details to compute running means.
… and moved post batch logging to separate method
Generally I think this refactor makes sense, but the function stack when calling train() is trending toward massive complexity. I want to challenge you to simplify the stack trace.
sup3r/models/abstract.py (Outdated)

        return hi_res

    def _get_hr_exo_and_loss(
I could go either way on this, but my knee-jerk reaction is that breaking these lines out into a separate function in a different file just makes the stack trace deeper for little benefit. This function is only called in one place, in a different file, inside a relatively short parent function. Seems like we could leave it as-is for less function nesting? My gut feeling is that three direct function calls without any logic are portable enough not to need packaging into a separate function.
A lot of these extractions are motivated by the work on models with observations. I could delay this until that PR if you prefer.
I guess I would have to see that work too, but I'd be shocked if a 14-line function really helps reduce the burden of 3 function calls. I really think we should just call the 3 functions directly. More nested functions reduce docstring quality and make it much harder to trace args/kwargs.
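For illustration, a rough sketch of the two layouts being debated. The helper names here are hypothetical stand-ins, not the actual sup3r API; only `_get_hr_exo_and_loss` comes from the diff above.

```python
class ModelSketch:
    """Illustrative only; `generate`, `add_exo_data`, and `compute_loss`
    are made-up names for the three calls under discussion."""

    # Option A: call the three steps directly inside the training method.
    def train_step_inline(self, batch):
        hi_res = self.generate(batch.low_res)
        hi_res = self.add_exo_data(batch, hi_res)
        return self.compute_loss(batch.high_res, hi_res)

    # Option B: wrap the same three calls in a thin helper, which adds a
    # stack frame and another docstring to maintain for the same behavior.
    def _get_hr_exo_and_loss(self, batch):
        hi_res = self.generate(batch.low_res)
        hi_res = self.add_exo_data(batch, hi_res)
        return self.compute_loss(batch.high_res, hi_res)

    def train_step_wrapped(self, batch):
        return self._get_hr_exo_and_loss(batch)
```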
This removes ~50 lines of duplication in the obs branch but we can decide if it's worth doing in that PR.
sup3r/models/base.py (Outdated)

        )
        return self._val_record.mean(axis=0)

    def _get_batch_loss_details(
There's a lot more going on here than just getting the loss details! This is running a full gradient descent step, including updating model parameters, so the function name and docstring are misleading.
I'm on the fence on this one; I'm not convinced we need to split this out into its own function, for similar reasons (it's only called once, and would we ever call this outside of a training loop?). There are quite a few lines, but it's not that complicated and neither is the parent function.
Yeah, good point on naming. Same comment on extraction: in the work on models with observations it's helpful to have this pulled out, but I can delay it until that PR.
So now we have _run_gradient_descent and run_gradient_descent, and they take different args and return different things? I don't love that haha. If you simply must have this be a separate function, what about renaming _run_gradient_descent to _train_batch, and then maybe also consider whether train_epoch should be hidden as _train_epoch to match.
Yeah I like that better
sup3r/models/base.py (Outdated)

        gen_too_good = disc_too_bad

        if not self.generator_weights:
            self.init_weights(batch.low_res.shape, batch.high_res.shape)
Minor simplification: let's move this function call to the start of train() and put the if not self.generator_weights check inside init_weights().
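A minimal sketch of that suggestion, assuming the signatures shown in the diff above and the `shapes` property on `AbstractBatchQueue` mentioned in the commit messages; it is not the exact sup3r implementation.

```python
class GanSketch:
    """Illustrative only; shows init_weights as a safe no-op when called twice."""

    def init_weights(self, lr_shape, hr_shape):
        """Build weights from example input shapes; no-op if already built."""
        if self.generator_weights:
            return
        # ... build generator/discriminator weights from lr_shape / hr_shape ...

    def train(self, batch_handler):
        # Called once at the start of train() instead of inside the batch loop;
        # `batch_handler.shapes` is the property referenced in the commit log.
        lr_shape, hr_shape = batch_handler.shapes
        self.init_weights(lr_shape, hr_shape)
        # ... per-batch training loop over batch_handler ...
```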
…ain` method. content on extractions in comments.
…queue_shape for this.
…s samples. Added `shapes` property to `AbstractBatchQueue` to use for `init_weights`
Bnb/disc training fix
A running mean of the disc loss is compared to disc loss thresholds for each batch to decide whether the discriminator should be trained on that batch. This running mean was not including loss values from batches in the previous epoch, which introduced a jump in the running mean at the start of each epoch. I initially resolved this by just getting the last value in the history, but decided that the running mean code was convoluted and needed a rework. This is now handled with a dataframe queue (self._train_record / self._val_record) of loss details for the past N batches, where N is the number of batches per epoch. The running mean is easily computed by calling self._train_record.mean().
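For reference, a minimal sketch of the rolling-record idea described above. The class and method names here are illustrative; the actual records live on the model as self._train_record / self._val_record in sup3r/models/base.py.

```python
import pandas as pd

class LossRecord:
    """Rolling record of per-batch loss details for the last n_batches."""

    def __init__(self, n_batches):
        # n_batches is one epoch's worth of batches
        self.n_batches = n_batches
        self.record = pd.DataFrame()

    def update(self, loss_details):
        # Append the newest batch's loss details and drop the oldest rows
        row = pd.DataFrame([loss_details])
        self.record = pd.concat([self.record, row], ignore_index=True)
        self.record = self.record.iloc[-self.n_batches:]

    def running_mean(self):
        # Mean over the last n_batches of loss details
        return self.record.mean(axis=0)
```

Because the record always holds the last N batches regardless of epoch boundaries, the running mean carries over from the end of one epoch into the start of the next instead of resetting.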