Numeric Embedding Implementation Strategy #8
Replies: 3 comments 1 reply
-
One idea is to make the manager a factory. It would be passed into the init methods of the component classes and would produce a position-encoding object in each component that needs one. That way you still handle almost everything from one class, the encodings live where they are needed throughout the program, and you could in principle turn them on or off for different parts of the model (for instance, to try ALiBi in one layer but not the others).
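A minimal sketch of the factory idea, assuming plain Python lists in place of tensors so it runs without PyTorch. PositionEncodingFactory, AlibiStyleBias, AttentionLayer and the enabled_layers parameter are hypothetical names, not part of the project, and the bias is a simplified, symmetric ALiBi-style distance penalty rather than a faithful ALiBi implementation. The point it shows is that each component asks the factory once, at init time, so forward() never needs the manager.

```python
from typing import Callable, List, Optional, Set

Matrix = List[List[float]]


class AlibiStyleBias:
    """Hypothetical encoder: adds -slope * |i - j| to attention scores."""

    def __init__(self, slope: float):
        self.slope = slope

    def bias(self, seq_len: int) -> Matrix:
        return [[-self.slope * abs(i - j) for j in range(seq_len)]
                for i in range(seq_len)]


class PositionEncodingFactory:
    """Hypothetical factory: decides, per layer, which encoder a component gets."""

    def __init__(self, builder: Callable[[], AlibiStyleBias], enabled_layers: Set[int]):
        self.builder = builder
        self.enabled_layers = enabled_layers

    def create(self, layer_idx: int) -> Optional[AlibiStyleBias]:
        # None means "no positional bias in this layer".
        return self.builder() if layer_idx in self.enabled_layers else None


class AttentionLayer:
    def __init__(self, layer_idx: int, factory: PositionEncodingFactory):
        # The encoder is created once here; forward-time code never sees the factory.
        self.pos_bias = factory.create(layer_idx)

    def scores(self, raw: Matrix) -> Matrix:
        if self.pos_bias is None:
            return raw
        b = self.pos_bias.bias(len(raw))
        return [[r + bb for r, bb in zip(row, brow)] for row, brow in zip(raw, b)]


# ALiBi-style bias in layer 0 only, as in the "one layer but not the others" example.
factory = PositionEncodingFactory(lambda: AlibiStyleBias(slope=0.5), enabled_layers={0})
layers = [AttentionLayer(i, factory) for i in range(3)]
raw = [[1.0, 1.0], [1.0, 1.0]]
print(layers[0].scores(raw))  # [[1.0, 0.5], [0.5, 1.0]]
print(layers[1].scores(raw))  # [[1.0, 1.0], [1.0, 1.0]]
```

In a real model the factory would be handed to each nn.Module's __init__ and would return nn.Module encoders, so the per-layer configuration still lives in one place.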
-
Alternative idea, and one I'm pretty sold on. Part of the clunkiness of the manager class is that it explicitly calls every existing positional embedding strategy, with passthrough classes inserted wherever you don't want a particular strategy. This will become unreadable quickly as new strategies are added: each addition requires not only a new class but also an update to the manager. Since most end users will not combine multiple positional encoding strategies, the manager will keep accumulating code that is irrelevant to them.

I propose a manager structure that does not explicitly call every existing positional encoding strategy (switched on or off). Instead, I'd create abstract classes that correspond to the different points at which the manager is called. A positional encoding strategy used at a given point implements the matching abstract class, which defines a call method of some kind. When the manager is initialized, the end user passes in a list of encoding instances. When the manager runs a given call, it loops over those encodings: if an encoding inherits from the abstract class for that call, the manager calls the method that class defines; otherwise, it skips the encoding. This simplifies the code dramatically and puts the onus of customization and inclusion on the positional encoding strategy classes. Aside from bug fixes and support for new kinds of call points that no current strategy needs, the manager code would stay untouched.
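A minimal sketch of this structure, assuming plain Python lists in place of tensors so it runs without PyTorch. InputEmbeddingHook, AttentionBiasHook, PluggableEmbeddingManager, AddPositionIndex and AlibiStyleBias are hypothetical names, not classes from the project; the two abstract classes stand in for two of the "different times when the manager would be called".

```python
from abc import ABC, abstractmethod
from typing import List, Sequence

Matrix = List[List[float]]


class InputEmbeddingHook(ABC):
    """Hypothetical call point: runs on token embeddings before the first layer."""

    @abstractmethod
    def apply_to_embeddings(self, x: Matrix) -> Matrix: ...


class AttentionBiasHook(ABC):
    """Hypothetical call point: runs on attention scores inside each layer."""

    @abstractmethod
    def apply_to_scores(self, scores: Matrix) -> Matrix: ...


class PluggableEmbeddingManager:
    """Hypothetical manager: knows only the call points, never the concrete strategies."""

    def __init__(self, encodings: Sequence[object]):
        self.encodings = list(encodings)

    def embed(self, x: Matrix) -> Matrix:
        for enc in self.encodings:
            # Strategies that don't implement this call point are simply skipped.
            if isinstance(enc, InputEmbeddingHook):
                x = enc.apply_to_embeddings(x)
        return x

    def bias_scores(self, scores: Matrix) -> Matrix:
        for enc in self.encodings:
            if isinstance(enc, AttentionBiasHook):
                scores = enc.apply_to_scores(scores)
        return scores


class AddPositionIndex(InputEmbeddingHook):
    """Toy absolute encoding: adds scale * position to every feature."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def apply_to_embeddings(self, x: Matrix) -> Matrix:
        return [[v + self.scale * pos for v in row] for pos, row in enumerate(x)]


class AlibiStyleBias(AttentionBiasHook):
    """Toy ALiBi-style bias: subtracts slope * |i - j| from each score."""

    def __init__(self, slope: float = 0.5):
        self.slope = slope

    def apply_to_scores(self, scores: Matrix) -> Matrix:
        return [[s - self.slope * abs(i - j) for j, s in enumerate(row)]
                for i, row in enumerate(scores)]


manager = PluggableEmbeddingManager([AddPositionIndex(), AlibiStyleBias()])
print(manager.embed([[0.0, 0.0], [0.0, 0.0]]))        # [[0.0, 0.0], [1.0, 1.0]]
print(manager.bias_scores([[1.0, 1.0], [1.0, 1.0]]))  # [[1.0, 0.5], [0.5, 1.0]]
```

Adding a new strategy here means writing one subclass of the relevant abstract class; PluggableEmbeddingManager itself does not change.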
With this structure, we could get rid of the passthrough classes I created earlier, and for NAS (neural architecture search) the question becomes simply whether a given strategy is included in the list passed to the manager at initialization. In principle, if a positional encoding needs to be applied at multiple points, it would either inherit from multiple abstract parent classes, or a single abstract class would need to define multiple methods to call. I don't see that happening with current approaches, but it's a possibility to consider.
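A short sketch of both points, assuming the same kind of call-point abstract classes as in the proposal (all names here are hypothetical, and plain lists stand in for tensors). TwoSiteEncoding acts at two call points simply by inheriting from both abstract classes, and the NAS search space is just every subset of the candidate strategies.

```python
from abc import ABC, abstractmethod
from itertools import combinations


class InputEmbeddingHook(ABC):
    @abstractmethod
    def apply_to_embeddings(self, x): ...


class AttentionBiasHook(ABC):
    @abstractmethod
    def apply_to_scores(self, scores): ...


class PluggableEmbeddingManager:
    def __init__(self, encodings):
        self.encodings = list(encodings)

    def _run(self, hook_type, method_name, value):
        # Call method_name on every encoding that implements hook_type, in list order.
        for enc in self.encodings:
            if isinstance(enc, hook_type):
                value = getattr(enc, method_name)(value)
        return value

    def embed(self, x):
        return self._run(InputEmbeddingHook, "apply_to_embeddings", x)

    def bias_scores(self, scores):
        return self._run(AttentionBiasHook, "apply_to_scores", scores)


class TwoSiteEncoding(InputEmbeddingHook, AttentionBiasHook):
    """Hypothetical strategy needed at two call points, so it inherits from both."""

    def apply_to_embeddings(self, x):
        return [[v + pos for v in row] for pos, row in enumerate(x)]

    def apply_to_scores(self, scores):
        return [[s - abs(i - j) for j, s in enumerate(row)] for i, row in enumerate(scores)]


class ScaleEmbeddings(InputEmbeddingHook):
    """Hypothetical single-point strategy: doubles every embedding value."""

    def apply_to_embeddings(self, x):
        return [[2 * v for v in row] for row in x]


candidates = [TwoSiteEncoding(), ScaleEmbeddings()]
# NAS search space: each candidate strategy is either in or out of the manager's list.
search_space = [list(c) for r in range(len(candidates) + 1)
                for c in combinations(candidates, r)]
managers = [PluggableEmbeddingManager(subset) for subset in search_space]
print(len(managers))  # 4
```

No passthrough classes are involved: leaving a strategy out of the list is how it gets switched off.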
-
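One idea I explored but ultimately rejected was to use a hook of some kind to auto-populate each component with the numeric embedding manager in advance, so you wouldn't need to pass it through the forward pass. However, the …

```python
import pytorch_lightning as pl
import torch

# Manager and Component are the project's own classes, defined elsewhere.


class MyLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.manager = Manager()
        self.components = torch.nn.ModuleList([Component() for _ in range(3)])
        # Hook each component directly; ModuleList.apply() would also visit the
        # ModuleList itself and every submodule nested inside each component.
        for component in self.components:
            component.register_forward_pre_hook(self._inject_manager)

    def _inject_manager(self, module, inputs):
        # Runs just before module.forward(): attach the shared manager.
        module.manager = self.manager
        return None  # None keeps the original inputs unchanged

    def forward(self, x):
        return sum(component(x) for component in self.components)
```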
-
Currently, every numeric/position embedding strategy is run through a single class, NumericEmbeddingManager. An instance of this class is passed through the model's forward pass.
This solution works and lets you access or adjust all numeric/position embedding strategies in a single location (a huge plus). But it feels clunky to pass the whole manager through the forward pass every time. I've looked through design patterns for a solution to this problem, but haven't yet landed on a simpler, more elegant one.
Anyone have any ideas on ways to improve this?
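For reference, the current pattern (one manager threaded through every forward call) can be sketched roughly as follows. EmbeddingManagerStub is a hypothetical stand-in for NumericEmbeddingManager, whose internals aren't shown in this thread; Block, Model, use_alibi and the toy computation are likewise illustrative, with plain lists in place of tensors. The thing to notice is that every forward() signature has to accept and pass on the manager.

```python
from typing import List

Matrix = List[List[float]]


class EmbeddingManagerStub:
    """Hypothetical stand-in for NumericEmbeddingManager: all strategies live here."""

    def __init__(self, use_alibi: bool, slope: float = 0.5):
        self.use_alibi = use_alibi
        self.slope = slope

    def attention_bias(self, seq_len: int) -> Matrix:
        if not self.use_alibi:
            return [[0.0] * seq_len for _ in range(seq_len)]
        return [[-self.slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]


class Block:
    def forward(self, x: Matrix, manager: EmbeddingManagerStub) -> Matrix:
        # Toy computation: add the manager's bias to the incoming values.
        bias = manager.attention_bias(len(x))
        return [[v + b for v, b in zip(row, brow)] for row, brow in zip(x, bias)]


class Model:
    def __init__(self, n_blocks: int):
        self.blocks = [Block() for _ in range(n_blocks)]

    def forward(self, x: Matrix, manager: EmbeddingManagerStub) -> Matrix:
        # The manager must be threaded through every level of the call stack.
        for block in self.blocks:
            x = block.forward(x, manager)
        return x


model = Model(n_blocks=2)
print(model.forward([[0.0, 0.0], [0.0, 0.0]], EmbeddingManagerStub(use_alibi=True)))
# [[0.0, -1.0], [-1.0, 0.0]]
```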