diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index 3d5788b13..62d999f1b 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -347,7 +347,7 @@ def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame: out.feat_dict = {stype: fn(x) for stype, x in out.feat_dict.items()} if out.y is not None: y = fn(out.y) - assert isinstance(y, Tensor) + assert isinstance(y, (Tensor, MultiNestedTensor)) out.y = y return out