@@ -23,10 +23,10 @@ def setup(self, stage=None):
23
23
self .valid_ds = TensorDataset (X_valid , y_valid )
24
24
25
25
def train_dataloader (self ):
26
- return DataLoader (self .train_ds , batch_size = self .batch_size , shuffle = True )
26
+ return DataLoader (self .train_ds , batch_size = self .batch_size , shuffle = True , num_workers = 1 )
27
27
28
28
def val_dataloader (self ):
29
- return DataLoader (self .valid_ds , batch_size = self .batch_size , shuffle = False )
29
+ return DataLoader (self .valid_ds , batch_size = self .batch_size , shuffle = False , num_workers = 1 )
30
30
31
31
32
32
class LitClassifier (pl .LightningModule ):
@@ -54,7 +54,6 @@ def validation_step(self, batch, batch_idx):
54
54
def configure_optimizers (self ):
55
55
return torch .optim .Adam (self .parameters (), lr = 1e-2 )
56
56
57
-
58
57
class TestPytorchLightning (unittest .TestCase ):
59
58
60
59
def test_version (self ):
@@ -64,5 +63,8 @@ def test_mnist(self):
64
63
dm = LitDataModule ()
65
64
model = LitClassifier ()
66
65
trainer = pl .Trainer (gpus = None , max_epochs = 1 )
67
- result = trainer .fit (model , datamodule = dm )
68
- self .assertTrue (result )
66
+ trainer .fit (model , datamodule = dm )
67
+
68
+ self .assertIn ("train_loss" , trainer .logged_metrics )
69
+ self .assertIn ("val_loss" , trainer .logged_metrics )
70
+
0 commit comments