Implement interface for training scheduling #15

@crybot

Interface Proposal (WIP)

Concept

TrainingStrategyCallback is an abstract callback interface that contains information on how to structure the training loop. Such information might include:

  • Curriculum Learning strategy information (which metric is used to determine 'hardness', how many stages of training to do, which method, etc.)
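
As a rough illustration of the idea, the sketch below treats a curriculum stage as a filter over the dataset based on a difficulty score. Everything in it is hypothetical and not part of the proposal: the `difficulty` metric (sequence length as a stand-in for 'hardness'), the `select_indices` helper, and the toy data.

```python
from typing import Callable, List, Sequence


def difficulty(sample: Sequence[int]) -> float:
    """Hypothetical hardness metric: longer samples count as harder."""
    return float(len(sample))


def select_indices(
    dataset: Sequence[Sequence[int]],
    metric: Callable[[Sequence[int]], float],
    threshold: float,
) -> List[int]:
    """Return indices of samples whose difficulty does not exceed the threshold."""
    return [i for i, sample in enumerate(dataset) if metric(sample) <= threshold]


# Toy dataset: four token sequences of increasing length.
dataset = [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]

# An early stage admits only short samples; a later stage raises the threshold.
stage_one = select_indices(dataset, difficulty, threshold=2.0)
stage_two = select_indices(dataset, difficulty, threshold=4.0)
print(stage_one)  # [0, 1]
print(stage_two)  # [0, 1, 2, 3]
```

With PyTorch, indices like these could be passed to `torch.utils.data.Subset` before rebuilding the dataloader, though the proposal leaves that mechanism open.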

Interface

TrainingStrategy might just be a wrapper around TrainingLoop:

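from abc import abstractmethod

from torch import Tensor

# TrainingCallback is the project's existing callback base class (import omitted).


class BaseTrainingStrategy(TrainingCallback):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    @abstractmethod
    def difficulty(self, x: Tensor) -> float:
        pass

    @abstractmethod
    def threshold(self, epoch) -> float:
        pass

    def on_train_epoch_end(self, state):
        super().on_train_epoch_end(state)
        epoch = state.get_state('epoch')
        # Keep only samples whose difficulty is within the current threshold.
        # Any additional arguments to `threshold` are still to be decided.
        filter_fn = lambda x: self.difficulty(x) <= self.threshold(epoch)
        # Updates the filter and somehow reloads the dataloaders to accommodate new datapoints.
        state.update_filter(filter_fn)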

class BabyStepCallback(BaseTrainingStrategy):
    def __init__(self, buckets: int, epochs_per_bucket: int):
        super().__init__()
        self.buckets = buckets
        self.epochs_per_bucket = epochs_per_bucket

    # implementations of `difficulty` and `threshold`
    def threshold(self, epoch) -> float:
        ...
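
One possible reading of the `threshold` schedule for Baby Step, sketched as a standalone class so it runs without the project's `TrainingCallback`. The assumptions are mine, not the proposal's: difficulties are normalised to the range [0, 1], one more bucket is unlocked every `epochs_per_bucket` epochs, and `BabyStepSchedule` is a hypothetical name.

```python
class BabyStepSchedule:
    """Hypothetical schedule: unlock one more difficulty bucket every few epochs."""

    def __init__(self, buckets: int, epochs_per_bucket: int):
        if buckets < 1 or epochs_per_bucket < 1:
            raise ValueError("buckets and epochs_per_bucket must be positive")
        self.buckets = buckets
        self.epochs_per_bucket = epochs_per_bucket

    def threshold(self, epoch: int) -> float:
        """Fraction of the (assumed normalised) difficulty range unlocked at `epoch`."""
        # Clamp the stage so the threshold stays at 1.0 once all buckets are unlocked.
        stage = min(epoch // self.epochs_per_bucket, self.buckets - 1)
        return (stage + 1) / self.buckets

    def state_dict(self) -> dict:
        """Plain dict so the schedule can be saved alongside a checkpoint."""
        return {"buckets": self.buckets, "epochs_per_bucket": self.epochs_per_bucket}

    def load_state_dict(self, state_dict: dict) -> None:
        self.buckets = state_dict["buckets"]
        self.epochs_per_bucket = state_dict["epochs_per_bucket"]


schedule = BabyStepSchedule(buckets=4, epochs_per_bucket=2)
thresholds = [schedule.threshold(epoch) for epoch in range(10)]
print(thresholds)
# [0.25, 0.25, 0.5, 0.5, 0.75, 0.75, 1.0, 1.0, 1.0, 1.0]

# Round-trip the configuration, as a checkpoint resume might.
restored = BabyStepSchedule(buckets=1, epochs_per_bucket=1)
restored.load_state_dict(schedule.state_dict())
```

Returning a fraction rather than a raw difficulty value keeps the schedule independent of the metric's scale; a real implementation might instead use per-bucket quantiles of the dataset's difficulty distribution.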

Examples

training_loop = TrainingLoop(..., callbacks=[BabyStepCallback(buckets, epochs_per_bucket)])
training_loop.run(model, epochs)
