Interface Proposal (WIP)
Concept
`TrainingStrategyCallback` is an abstract callback interface that describes how the training loop should be structured. This information might include:
- Curriculum learning strategy information (which metric measures 'hardness', how many training stages to run, which method to use, etc.)
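To make the concept concrete, here is a minimal, framework-free sketch of what a curriculum strategy decides: given a hardness metric and a number of stages, which datapoints each stage trains on. It assumes a cumulative schedule (each stage keeps the easier data and adds the next slice). `curriculum_stages` is a hypothetical helper, not part of the repository, and string length stands in for the hardness metric.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def curriculum_stages(
    samples: Sequence[T],
    difficulty: Callable[[T], float],
    num_stages: int,
) -> List[List[T]]:
    """Split samples into cumulative training pools, easiest first.

    Stage k holds the easiest ceil((k + 1) * n / num_stages) samples,
    so the final stage always covers the whole dataset.
    """
    if num_stages < 1:
        raise ValueError("num_stages must be >= 1")
    ordered = sorted(samples, key=difficulty)  # easiest samples first
    n = len(ordered)
    pools = []
    for stage in range(num_stages):
        # integer ceiling of (stage + 1) * n / num_stages
        cutoff = ((stage + 1) * n + num_stages - 1) // num_stages
        pools.append(ordered[:cutoff])
    return pools


pools = curriculum_stages(["a", "abcd", "ab", "abc", "abcdef", "abcde"], difficulty=len, num_stages=3)
print([len(p) for p in pools])  # → [2, 4, 6]
```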
Interface
`TrainingStrategy` might just be a wrapper around `TrainingLoop`:
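If the strategy were a wrapper rather than a callback, it would own the stage schedule and drive the loop itself. The sketch below assumes a hypothetical `run_epoch(data)` method on the wrapped loop (the real `TrainingLoop` API may differ); `CurriculumWrapper`, `EpochRunner` and `RecordingLoop` are illustrative names only.

```python
from typing import Callable, List, Protocol, Sequence


class EpochRunner(Protocol):
    """Hypothetical minimal loop interface: train for one epoch on `data`."""

    def run_epoch(self, data: Sequence[float]) -> None: ...


class CurriculumWrapper:
    """Wrapper-style strategy: owns the stage schedule and drives the inner loop."""

    def __init__(
        self,
        loop: EpochRunner,
        difficulty: Callable[[float], float],
        thresholds: Sequence[float],
        epochs_per_stage: int,
    ) -> None:
        self.loop = loop
        self.difficulty = difficulty
        self.thresholds = thresholds  # one increasing difficulty cut-off per stage
        self.epochs_per_stage = epochs_per_stage

    def run(self, dataset: Sequence[float]) -> List[int]:
        """Train stage by stage; return the number of admitted samples per stage."""
        admitted_sizes = []
        for limit in self.thresholds:
            admitted = [x for x in dataset if self.difficulty(x) <= limit]
            for _ in range(self.epochs_per_stage):
                self.loop.run_epoch(admitted)
            admitted_sizes.append(len(admitted))
        return admitted_sizes


class RecordingLoop:
    """Toy stand-in for TrainingLoop that records the size of each epoch's data."""

    def __init__(self) -> None:
        self.epoch_sizes: List[int] = []

    def run_epoch(self, data: Sequence[float]) -> None:
        self.epoch_sizes.append(len(data))


loop = RecordingLoop()
wrapper = CurriculumWrapper(loop, difficulty=abs, thresholds=[1.0, 2.0, 3.0], epochs_per_stage=2)
print(wrapper.run([0.5, -1.5, 2.5, -0.2, 3.0]))  # → [2, 3, 5]
print(loop.epoch_sizes)  # → [2, 2, 3, 3, 5, 5]
```

The trade-off: a wrapper keeps curriculum logic out of the loop entirely, whereas the callback design in this proposal requires the loop to expose a way to swap the data filter mid-run.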
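```python
from abc import ABC, abstractmethod

from torch import Tensor  # assuming datapoints are PyTorch tensors

# `TrainingCallback` and the loop `state` object come from this repository;
# `state.update_filter` is part of this proposal and does not exist yet.


class BaseTrainingStrategy(TrainingCallback, ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def state_dict(self) -> dict:
        """Return the strategy's internal state for checkpointing."""

    @abstractmethod
    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the strategy's internal state from a checkpoint."""

    @abstractmethod
    def difficulty(self, x: Tensor) -> float:
        """Score how hard a single datapoint is (higher = harder)."""

    @abstractmethod
    def threshold(self, epoch: int) -> float:
        """Largest difficulty admitted into the training set at `epoch`."""

    def on_train_epoch_end(self, state):
        super().on_train_epoch_end(state)
        limit = self.threshold(state.get_state('epoch'))

        def filter_fn(x: Tensor) -> bool:
            # keep only datapoints that are easy enough for the current stage
            return self.difficulty(x) <= limit

        # updates the filter and (mechanism TBD) reloads the dataloaders
        # to accommodate the newly admitted datapoints
        state.update_filter(filter_fn)


class BabyStepCallback(BaseTrainingStrategy):
    def __init__(self, buckets: int, epochs_per_bucket: int):
        super().__init__()
        self.buckets = buckets
        self.epochs_per_bucket = epochs_per_bucket

    # implementations of `state_dict`, `load_state_dict`, `difficulty` and `threshold`
    def threshold(self, epoch: int) -> float:
        ...
```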
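One way `BabyStepCallback.threshold` could be filled in, assuming difficulty scores for the whole dataset are available up front: sort the scores, cut them into `buckets` equally sized groups, and admit one more group every `epochs_per_bucket` epochs while keeping the earlier ones (a cumulative "baby step" schedule). `BabyStepSchedule` is a hypothetical standalone class so that it runs without the repository; each bucket's bound is the score of its hardest member.

```python
from typing import List, Sequence


class BabyStepSchedule:
    """Hypothetical Baby Step threshold schedule (cumulative buckets)."""

    def __init__(self, scores: Sequence[float], buckets: int, epochs_per_bucket: int) -> None:
        if buckets < 1 or epochs_per_bucket < 1:
            raise ValueError("buckets and epochs_per_bucket must be >= 1")
        if not scores:
            raise ValueError("scores must be non-empty")
        ordered = sorted(scores)
        n = len(ordered)
        # Upper difficulty bound of bucket b: the score at position
        # ceil((b + 1) * n / buckets) - 1 in the sorted list.
        self.bounds: List[float] = [
            ordered[((b + 1) * n + buckets - 1) // buckets - 1] for b in range(buckets)
        ]
        self.epochs_per_bucket = epochs_per_bucket

    def threshold(self, epoch: int) -> float:
        """Largest difficulty admitted at `epoch` (0-based); clamps to the last bucket."""
        bucket = min(epoch // self.epochs_per_bucket, len(self.bounds) - 1)
        return self.bounds[bucket]


schedule = BabyStepSchedule([5, 1, 4, 2, 8, 7, 3, 6], buckets=4, epochs_per_bucket=2)
print([schedule.threshold(e) for e in range(8)])  # → [2, 2, 4, 4, 6, 6, 8, 8]
```

Thresholding on a score value (rather than on bucket membership) means ties at a bucket boundary are all admitted together, which keeps the filter a simple `difficulty(x) <= threshold(epoch)` comparison as in the proposal.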
Examples
```python
training_loop = TrainingLoop(..., callbacks=[BabyStepCallback(buckets=..., epochs_per_bucket=...)])
training_loop.run(model, epochs)
```
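The following self-contained toy run mimics the proposed flow end to end. `ToyState`, `ToyBabyStep` and `run_toy_loop` are hypothetical stand-ins for the repository's loop state, the callback, and `TrainingLoop`, and absolute value stands in for the difficulty metric. It also surfaces one consequence of the design as written: because the filter is only installed in `on_train_epoch_end`, the first epoch trains on the unfiltered dataset unless the loop sets an initial filter.

```python
from typing import Callable, Dict, List, Optional, Sequence


class ToyState:
    """Hypothetical stand-in for the loop state assumed by the proposal."""

    def __init__(self) -> None:
        self._state: Dict[str, int] = {"epoch": 0}
        self.filter_fn: Optional[Callable[[float], bool]] = None  # None = keep all

    def get_state(self, key: str) -> int:
        return self._state[key]

    def set_state(self, key: str, value: int) -> None:
        self._state[key] = value

    def update_filter(self, filter_fn: Callable[[float], bool]) -> None:
        self.filter_fn = filter_fn


class ToyBabyStep:
    """Callback following the proposed on_train_epoch_end logic."""

    def __init__(self, thresholds: Sequence[float], epochs_per_bucket: int) -> None:
        self.thresholds = thresholds
        self.epochs_per_bucket = epochs_per_bucket

    def difficulty(self, x: float) -> float:
        return abs(x)

    def threshold(self, epoch: int) -> float:
        bucket = min(epoch // self.epochs_per_bucket, len(self.thresholds) - 1)
        return self.thresholds[bucket]

    def on_train_epoch_end(self, state: ToyState) -> None:
        limit = self.threshold(state.get_state("epoch"))
        state.update_filter(lambda x: self.difficulty(x) <= limit)


def run_toy_loop(data: Sequence[float], callbacks: List[ToyBabyStep], epochs: int) -> List[int]:
    """Return how many samples each epoch trained on."""
    state = ToyState()
    sizes = []
    for epoch in range(epochs):
        state.set_state("epoch", epoch)
        fn = state.filter_fn
        admitted = [x for x in data if fn is None or fn(x)]
        sizes.append(len(admitted))  # a real loop would train on `admitted` here
        for cb in callbacks:
            cb.on_train_epoch_end(state)
    return sizes


print(run_toy_loop([0.5, -1.5, 2.5, -0.2, 3.0], [ToyBabyStep([1.0, 2.0, 3.0], 2)], epochs=6))
# → [5, 2, 2, 3, 3, 5]
```

The leading `5` is the unfiltered first epoch; after that, the admitted set grows one bucket every two epochs.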