
Conversation

@bnb32 bnb32 commented Feb 5, 2025

A running mean of the disc loss is compared to disc loss thresholds for each batch to decide whether the discriminator should be trained on that batch. This running mean was not including loss values from batches in the previous epoch, which introduced a jump in the running mean at the start of each epoch. I initially resolved this by just getting the last value in the history, but decided that the running mean code was convoluted and needed a rework. This is now handled with a dataframe queue (self._train_record / self._val_record) of loss details for the past N batches, where N is the number of batches per epoch. The running mean is then computed simply by calling self._train_record.mean().
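
For illustration, a minimal sketch of the rolling-record idea described above, assuming pandas. The _train_record attribute and the .mean() call come from the description; the column names, the n_batches argument, and the _record_batch_loss helper are hypothetical stand-ins, not the actual sup3r code.

import pandas as pd

class RollingLossRecord:
    """Sketch: keep loss details for the last N batches and expose a running mean."""

    def __init__(self, n_batches):
        self.n_batches = n_batches  # N = number of batches per epoch
        self._train_record = pd.DataFrame()

    def _record_batch_loss(self, loss_details):
        # Append one batch's loss details and drop anything older than N batches,
        # so the mean spans epoch boundaries without a jump.
        row = pd.DataFrame([loss_details])
        self._train_record = pd.concat(
            [self._train_record, row], ignore_index=True
        ).tail(self.n_batches)

    def running_mean(self):
        # Running mean over the most recent N batches.
        return self._train_record.mean()

record = RollingLossRecord(n_batches=4)
for i in range(10):
    record._record_batch_loss({'disc_loss': 0.1 * i, 'gen_loss': 1.0 - 0.05 * i})
print(record.running_mean())  # averages only the 4 most recent batches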

@bnb32 bnb32 requested a review from grantbuster February 6, 2025 16:01
@grantbuster grantbuster left a comment

Generally I think this refactor makes sense, but I think the function stack when calling train() is trending toward massive complexity. I want to challenge you to simplify the stack trace.


return hi_res

def _get_hr_exo_and_loss(

grantbuster (Member):

I could go either way on this, but my knee-jerk reaction is that breaking these lines out into a separate function in a different file just makes the stack trace deeper for little benefit. This function is only called in one place, in a different file, in a relatively short parent function. Seems like we could leave it as-is for less function nesting? My gut feeling is that three direct function calls without any logic are portable enough to not need packaging into a separate function.

bnb32 (Collaborator, Author):

A lot of these extractions are motivated by the work on models with observations. I could delay this until that PR if you prefer.

grantbuster (Member):

I guess I would have to see that work too, but I'd be shocked if a 14-line function really helps reduce the burden of 3 function calls? I really think we should just call the 3 functions directly. More nested functions reduce docstring quality and make it way harder to trace args/kwargs.

bnb32 (Collaborator, Author):

This removes ~50 lines of duplication in the obs branch but we can decide if it's worth doing in that PR.

)
return self._val_record.mean(axis=0)

def _get_batch_loss_details(

grantbuster (Member):

There's a lot more going on here than just getting the loss details! This is running a full gradient descent step, including updating model parameters. The function name and docstring are misleading.

I'm on the fence on this one; I'm not convinced we need to split this out into its own function, for similar reasons (it's only called once, and would we ever call this outside of a training loop?). There are quite a few lines, but it's not that complicated, and neither is the parent function.
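
For context on the naming concern, a rough sketch of why "getting loss details" here really amounts to one full gradient-descent step. This is illustrative only, assuming a TensorFlow-style model; the generator, optimizer, and calc_loss names are hypothetical stand-ins, not the actual sup3r code.

import tensorflow as tf

def _train_batch(generator, optimizer, calc_loss, low_res, high_res):
    # One full gradient-descent step: forward pass, loss, gradients,
    # and a parameter update -- not just loss bookkeeping.
    with tf.GradientTape() as tape:
        hi_res_gen = generator(low_res, training=True)
        loss, loss_details = calc_loss(high_res, hi_res_gen)
    grads = tape.gradient(loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return loss_details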

bnb32 (Collaborator, Author):

Yeah, good point on naming. Same comment on extraction - in the work on models with observations it's helpful to have this pulled out but I can delay this until that PR.

grantbuster (Member):

So now we have _run_gradient_descent and run_gradient_descent, and they both take different args and output different things? I don't love that haha. If you simply must have this be a separate function, what about renaming _run_gradient_descent -> _train_batch, and then maybe also consider whether train_epoch should be hidden as _train_epoch to match.

bnb32 (Collaborator, Author):

Yeah I like that better

gen_too_good = disc_too_bad

if not self.generator_weights:
    self.init_weights(batch.low_res.shape, batch.high_res.shape)

grantbuster (Member):

Minor simplification: let's move this function call to the start of train() and put the if not self.generator_weights check inside init_weights().
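
A quick sketch of this suggested simplification. This is a hypothetical skeleton, not the actual sup3r classes: the guard moves inside init_weights, and train() initializes weights once up front instead of checking per batch.

class Sup3rGanSketch:
    """Hypothetical skeleton showing the suggested weight-init flow."""

    def __init__(self):
        self.generator_weights = []  # empty until init_weights runs

    def init_weights(self, lr_shape, hr_shape):
        # The guard lives here now, so callers never need to repeat the check.
        if self.generator_weights:
            return
        self.generator_weights = [lr_shape, hr_shape]  # placeholder for the real weight build

    def train(self, batch_handler):
        # Initialize weights once at the start instead of inside the batch loop.
        first_batch = next(iter(batch_handler))
        self.init_weights(first_batch.low_res.shape, first_batch.high_res.shape)
        # ... epoch / batch loops follow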

@bnb32 bnb32 merged commit dfee45d into main Feb 21, 2025
12 checks passed
@bnb32 bnb32 deleted the bnb/disc_training_fix branch February 21, 2025 23:02
github-actions bot pushed a commit that referenced this pull request Feb 21, 2025