From 83a06f5730a8d29eddd1149607f7159353cf1b48 Mon Sep 17 00:00:00 2001 From: XinweiHe Date: Mon, 8 Jul 2024 18:20:02 +0000 Subject: [PATCH 1/3] fix --- torch_frame/data/tensor_frame.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index c9946b37a..5b0f96940 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -348,7 +348,6 @@ def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame: out.feat_dict = {stype: fn(x) for stype, x in out.feat_dict.items()} if out.y is not None: y = fn(out.y) - assert isinstance(y, Tensor) out.y = y return out From d6587b4b081be9d711ff713a25a629ae8ca9086e Mon Sep 17 00:00:00 2001 From: XinweiHe Date: Mon, 8 Jul 2024 18:52:44 +0000 Subject: [PATCH 2/3] fix --- torch_frame/data/tensor_frame.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index 5b0f96940..984595d36 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -348,6 +348,7 @@ def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame: out.feat_dict = {stype: fn(x) for stype, x in out.feat_dict.items()} if out.y is not None: y = fn(out.y) + assert isinstance(y, (torch.Tensor, MultiNestedTensor)) out.y = y return out From 8a2675ea90115df1ce369140ed20ca33bc28c25a Mon Sep 17 00:00:00 2001 From: XinweiHe Date: Mon, 8 Jul 2024 18:53:52 +0000 Subject: [PATCH 3/3] fix --- torch_frame/data/tensor_frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index 984595d36..04b629e2e 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -348,7 +348,7 @@ def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame: out.feat_dict = {stype: fn(x) for stype, x in out.feat_dict.items()} if out.y is not None: y = fn(out.y) - assert isinstance(y, (torch.Tensor, MultiNestedTensor)) + assert isinstance(y, (Tensor, MultiNestedTensor)) out.y = y return out