@@ -3,6 +3,7 @@
 from transformers import CLIPModel, CLIPProcessor
 from typing import Optional
 from ..data import FeatureDataset, ImageDataset
+from ..tools import check_dataset_output


 class DeepFeatures:
@@ -43,7 +44,8 @@ def __call__(self, dataset: ImageDataset) -> FeatureDataset:

         self.model = self.model.to(self.device)
         self.model = self.model.eval()
-
+
+        check_dataset_output(dataset, check_label=False)
         loader = torch.utils.data.DataLoader(
             dataset,
             num_workers=self.num_workers,
@@ -119,6 +121,7 @@ def __call__(self, dataset: ImageDataset) -> FeatureDataset:
         # TODO: this is hacky
         dataset.transforms = None  # Reset transforms.

+        check_dataset_output(dataset, check_label=False)
         loader = torch.utils.data.DataLoader(
             dataset,
             num_workers=self.num_workers,
@@ -127,7 +130,7 @@ def __call__(self, dataset: ImageDataset) -> FeatureDataset:
             collate_fn=lambda x: x,
         )
         outputs = []
-        for image in tqdm(loader, mininterval=1, ncols=100):
+        for image, _ in tqdm(loader, mininterval=1, ncols=100):
             with torch.no_grad():
                 output = self.model(self.transform(image).to(self.device)).pooler_output
             outputs.append(output.cpu())
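For context on the new guard: check_dataset_output is imported from ..tools, but its body is not part of this diff. The sketch below is an assumption of what the call verifies, not the actual implementation: that the dataset yields (image, label) tuples, which is what the for image, _ in tqdm(...) unpacking added in the last hunk relies on, while check_label=False (as used here) skips inspecting the label itself.

def check_dataset_output(dataset, check_label=True):
    # Hypothetical sketch (not from this PR): pull one sample and confirm the
    # dataset returns an (image, label) pair before building the DataLoader.
    sample = dataset[0]
    if not (isinstance(sample, tuple) and len(sample) == 2):
        raise ValueError(
            f"Expected dataset[i] to return an (image, label) tuple, got {type(sample).__name__}"
        )
    if check_label:
        # With check_label=False, as in both call sites of this diff,
        # the label is loaded but otherwise left unchecked.
        _, label = sample
        if label is None:
            raise ValueError("Expected a non-None label in the dataset output")

Passing check_label=False fits the feature-extraction path here, where labels travel through the loader but are discarded by the unpacking.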