Description
functorch does not play well with `TensorDictModule`s (especially when calling `vmap`).
Therefore we need to functionalize the modules inside the `TensorDictModule`, and then build some interface between the `TensorDictModule` and the functional module.
For instance, TDSequence has to split the parameters according to each of its submodules, etc.
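One way to picture the split is a minimal sketch that assumes the sequence's parameters arrive as a single flat tuple (as with functorch's `make_functional_with_buffers`) and that the number of parameters owned by each submodule is known. `split_params` is a hypothetical helper, not a TorchRL or tensordict API; strings stand in for tensors so the slicing logic is easy to follow.

```python
from itertools import accumulate
from typing import Any, List, Sequence, Tuple


def split_params(flat_params: Sequence[Any], lengths: Sequence[int]) -> List[Tuple[Any, ...]]:
    """Split a flat parameter sequence into one tuple per submodule.

    Hypothetical helper: lengths[i] is the number of parameters owned by
    submodule i, in the same order the submodules appear in the sequence.
    """
    if sum(lengths) != len(flat_params):
        raise ValueError(
            f"lengths sum to {sum(lengths)} but got {len(flat_params)} parameters"
        )
    ends = list(accumulate(lengths))  # exclusive end index of each slice
    starts = [0] + ends[:-1]          # each slice starts where the previous one ended
    return [tuple(flat_params[s:e]) for s, e in zip(starts, ends)]


# Three submodules: one with weight+bias, one parameter-free, one with four parameters.
parts = split_params(("w0", "b0", "w2", "b2", "g2", "h2"), [2, 0, 4])
print(parts)  # [('w0', 'b0'), (), ('w2', 'b2', 'g2', 'h2')]
```

For real `nn.Module` submodules the lengths could be obtained with `[len(list(m.parameters())) for m in submodules]`; since they normally do not change after construction, they are a natural candidate for the caching item in the list below.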
We should test all of these functionalities:
- creating `FunctionalModuleWithBuffers` from the TDModule, OR providing a `FunctionalModuleWithBuffers` to a TDModule upon construction
- nested TDModules (e.g. `ProbabilisticTensorDictModule`): param length, casting etc.
- TDSequence param split
- [upcoming feature] TDSequence param split: caching of param lengths
- `FunctionalModule` and `FunctionalModuleWithBuffers` (some logic might break with the former)
- `vmap` extension given the number of inputs to the module