Skip to content

Commit 26e37eb

Browse files
Fix some issues in DA training
1 parent 05d3555 commit 26e37eb

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

synaptic_reconstruction/training/semisupervised_training.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,26 @@ def get_unsupervised_loader(
6161
else:
6262
roi = None
6363

64-
z,y,x = patch_shape
65-
ndim = 2 if z == 1 else 3
64+
if len(patch_shape) == 2:
65+
ndim = 2
66+
else:
67+
assert len(patch_shape) == 3
68+
z, y, x = patch_shape
69+
ndim = 2 if z == 1 else 3
6670
print("ndim is: ", ndim)
6771

6872
raw_transform = torch_em.transform.get_raw_transform()
6973
transform = torch_em.transform.get_augmentations(ndim=ndim)
7074

75+
if n_samples is None:
76+
n_samples_per_ds = None
77+
else:
78+
n_samples_per_ds = int(n_samples / len(data_paths))
79+
7180
augmentations = (weak_augmentations(), weak_augmentations())
7281
datasets = [
7382
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
74-
augmentations=augmentations, roi=roi, ndim = ndim)
83+
augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
7584
for path in data_paths
7685
]
7786
ds = torch.utils.data.ConcatDataset(datasets)
@@ -136,9 +145,9 @@ def semisupervised_training(
136145
# check_loader(val_loader, n_samples=4)
137146
return
138147

139-
#check for 2D or 3D training
148+
# Check for 2D or 3D training
140149
is_2d = False
141-
z,y,x = patch_shape
150+
z, y, x = patch_shape
142151
is_2d = z == 1
143152

144153
if is_2d:

0 commit comments

Comments
 (0)