diff --git a/test/data/test_tensor_frame.py b/test/data/test_tensor_frame.py index a399ce20f..a713fde7f 100644 --- a/test/data/test_tensor_frame.py +++ b/test/data/test_tensor_frame.py @@ -206,3 +206,27 @@ def test_get_col_feat(get_fake_tensor_frame): else: assert torch.allclose(torch.cat(feat_list, dim=1), tf.feat_dict[stype]) + + +def test_custom_tf_get_col_feat(): + col_names_dict = { + 'categorical': ['cat_1', 'cat_2', 'cat_3'], + 'numerical': ['num_1', 'num_2'], + } + feat_dict = { + 'categorical': torch.randint(0, 3, size=(10, 3)), + 'numerical': torch.randn(10, 2), + } + + tf = TensorFrame(feat_dict=feat_dict, col_names_dict=col_names_dict) + + feat = tf.get_col_feat('cat_1') + assert torch.equal(feat, feat_dict['categorical'][:, 0:1]) + feat = tf.get_col_feat('cat_2') + assert torch.equal(feat, feat_dict['categorical'][:, 1:2]) + feat = tf.get_col_feat('cat_3') + assert torch.equal(feat, feat_dict['categorical'][:, 2:3]) + feat = tf.get_col_feat('num_1') + assert torch.equal(feat, feat_dict['numerical'][:, 0:1]) + feat = tf.get_col_feat('num_2') + assert torch.equal(feat, feat_dict['numerical'][:, 1:2]) diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index c9946b37a..3d5788b13 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -141,25 +141,24 @@ def get_col_feat(self, col_name: str) -> TensorData: is :obj:`[num_rows, 1, *]`. """ if col_name not in self._col_to_stype_idx: - raise ValueError( - f"{col_name} is not available in the TensorFrame object.") + raise ValueError(f"'{col_name}' is not available in the " + f"'{self.__class__.__name__}' object") + stype_name, idx = self._col_to_stype_idx[col_name] + feat = self.feat_dict[stype_name] - if stype_name.use_dict_multi_nested_tensor: - assert isinstance(feat, dict) + if isinstance(feat, dict): col_feat: dict[str, MultiNestedTensor] = {} for key, mnt in feat.items(): value = mnt[:, idx] assert isinstance(value, MultiNestedTensor) col_feat[key] = value return col_feat + elif isinstance(feat, _MultiTensor): + return feat[:, idx] else: - if stype_name.use_multi_tensor: - assert isinstance(feat, _MultiTensor) - return feat[:, idx] - else: - assert isinstance(feat, Tensor) - return feat[:, idx].unsqueeze(1) + assert isinstance(feat, Tensor) + return feat[:, idx].unsqueeze(1) @property def stypes(self) -> list[torch_frame.stype]: