```shell
pip install lightning-trainer-utils
```
- The model wrapper implements the forward pass as `output = self.model(**x, **self.forward_kwargs)` followed by `return ModelOutput(**output)`. It expects the batch as a dict and returns a result with the keys `loss`, `report`, and `output`.
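The wrapper behavior above can be sketched roughly as follows. This is a minimal, dependency-free illustration, not the package's actual code: the `ModelOutput` fields and the `Wrapper` class shape are assumptions based only on the description in this README.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ModelOutput:
    # Assumed fields, mirroring the keys named in the text.
    loss: float
    report: Dict[str, Any] = field(default_factory=dict)
    output: Any = None

class Wrapper:
    """Hypothetical stand-in for the model wrapper described above."""

    def __init__(self, model, forward_kwargs=None):
        self.model = model
        self.forward_kwargs = forward_kwargs or {}

    def forward(self, x: Dict[str, Any]) -> ModelOutput:
        # The batch dict `x` is unpacked as keyword arguments, merged
        # with the fixed forward_kwargs, and the model's dict result
        # is wrapped into a ModelOutput.
        output = self.model(**x, **self.forward_kwargs)
        return ModelOutput(**output)
```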
- The wrapped ML model should return a dict with the following keys:
  - `loss`
  - `report`
  - `output` (optional)
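A toy model conforming to that return contract might look like this. The loss computation here is a stand-in (mean squared error over a doubling "model"); only the shape of the returned dict reflects the README.

```python
def toy_model(inputs, targets):
    """Hypothetical model obeying the loss/report/output contract."""
    # Stand-in "model": double each input value.
    preds = [2 * v for v in inputs]
    # Stand-in loss: mean squared error against the targets.
    loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(inputs)
    return {
        "loss": loss,             # scalar the trainer backpropagates
        "report": {"mse": loss},  # metrics to be logged
        "output": preds,          # optional: raw predictions
    }
```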
- Step accounting:
  - `batch_step = num_samples / (batch_size * num_devices)`
  - `trainer_global_step = num_samples / (batch_size * num_devices * grad_accumulation)`
- `SaveCheckpoint` also uses `trainer_global_step`.
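A worked instance of the two step formulas, with illustrative numbers (the values below are assumptions, not defaults of the package):

```python
# Assumed example configuration.
num_samples = 12800
batch_size = 32
num_devices = 2
grad_accumulation = 4

# One batch_step per per-device batch across all devices.
batch_step = num_samples / (batch_size * num_devices)

# One trainer_global_step per optimizer update, i.e. every
# grad_accumulation batches.
trainer_global_step = num_samples / (batch_size * num_devices * grad_accumulation)
```

So 12,800 samples on 2 devices with batch size 32 yield 200 batch steps, but only 50 global steps once gradient accumulation of 4 is applied; checkpoint intervals based on `trainer_global_step` therefore count optimizer updates, not batches.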