@@ -186,8 +186,8 @@ def __repr__(self):
186
186
return self .__class__ .__name__ + '()'
187
187
188
188
189
- # 데이터셋 불러오는 코드 검증
190
- def show_dataset (images : torch .Tensor , masks : torch .Tensor ):
189
+ # 데이터셋 불러오는 코드 검증 (Shape: [batch, channel, height, width])
190
+ def show_dataset (image : torch .Tensor , target : torch .Tensor ):
191
191
def make_plt_subplot (nrows : int , ncols : int , index : int , title : str , image ):
192
192
plt .subplot (nrows , ncols , index )
193
193
plt .title (title )
@@ -197,8 +197,8 @@ def make_plt_subplot(nrows: int, ncols: int, index: int, title: str, image):
197
197
198
198
to_pil_image = torchvision .transforms .ToPILImage ()
199
199
200
- assert images .shape [0 ] == masks .shape [0 ]
201
- for i in range (images .shape [0 ]):
202
- make_plt_subplot (1 , 2 , 1 , 'Input image' , to_pil_image (images [i ].squeeze ().cpu ()))
203
- make_plt_subplot (1 , 2 , 2 , 'Groundtruth' , to_pil_image (masks [i ].cpu ()))
200
+ assert image .shape [0 ] == target .shape [0 ]
201
+ for i in range (image .shape [0 ]):
202
+ make_plt_subplot (1 , 2 , 1 , 'Input image' , to_pil_image (image [i ].squeeze ().cpu ()))
203
+ make_plt_subplot (1 , 2 , 2 , 'Groundtruth' , to_pil_image (target [i ].cpu ()))
204
204
plt .show ()
0 commit comments