@@ -139,7 +139,7 @@ def __init__(
139139
140140 source_channels = stream_info ["source" ] if "source" in stream_info else None
141141 if source_channels :
142- self .source_channels , self .source_idx = self .selec (source_channels )
142+ self .source_channels , self .source_idx = self .select (source_channels )
143143 else :
144144 self .source_channels = self .colnames
145145 self .source_idx = self .cols_idx
@@ -159,11 +159,10 @@ def select(self, ch_filters: list[str]) -> None:
159159 Allow user to specify which columns they want to access.
160160 Get functions only returned for these specified columns.
161161 """
162-
163162 mask = [np .array ([f in c for f in ch_filters ]).any () for c in self .colnames ]
164163
165- selected_cols_idx = np .where (mask )[0 ]
166- selected_colnames = [self .colnames [i ] for i in selected_cols_idx ]
164+ selected_cols_idx = self . cols_idx [ np .where (mask )[0 ] ]
165+ selected_colnames = [self .colnames [i ] for i in np . where ( mask )[ 0 ] ]
167166
168167 return selected_colnames , selected_cols_idx
169168
@@ -343,7 +342,7 @@ def normalize_target_channels(self, target: torch.tensor) -> torch.tensor:
343342 """
344343 assert target .shape [1 ] == len (self .target_idx )
345344 for i , ch in enumerate (self .target_idx ):
346- target [..., i ] = (target [..., i ] - self .mean [ch + 2 ]) / self .stdev [ch + 2 ]
345+ target [..., i ] = (target [..., i ] - self .mean [ch ]) / self .stdev [ch ]
347346
348347 return target
349348
@@ -380,7 +379,7 @@ def denormalize_target_channels(self, data: torch.tensor) -> torch.tensor:
380379 """
381380 assert data .shape [- 1 ] == len (self .target_idx ), "incorrect number of channels"
382381 for i , ch in enumerate (self .target_idx ):
383- data [..., i ] = (data [..., i ] * self .stdev [ch + 2 ]) + self .mean [ch + 2 ]
382+ data [..., i ] = (data [..., i ] * self .stdev [ch ]) + self .mean [ch ]
384383
385384 return data
386385
0 commit comments