Skip to content

Commit 92d803f

Browse files
committed
train: Implement train_interupter()
1 parent be3fff4 commit 92d803f

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
prev_miou = 0.0
3737
prev_val_loss = 0.0
3838
for epoch in tqdm.tqdm(range(config[config['model']]['epoch']), desc='Epoch'):
39+
if utils.train_interupter():
40+
print('Train interrupt occurs.')
41+
break
3942
model.train()
4043

4144
for batch_idx, (image, target) in enumerate(tqdm.tqdm(trainloader, desc='Train', leave=False)):

train_interupter.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0

utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,20 @@ def __call__(self, image, target):
187187
return image, target
188188

189189

190+
def train_interupter():
191+
with open('train_interupter.txt', 'r', encoding='utf-8') as f:
192+
flag = f.read()
193+
194+
if flag == '0':
195+
return False
196+
elif flag == '1':
197+
with open('train_interupter.txt', 'w', encoding='utf-8') as f:
198+
f.write('0')
199+
return True
200+
else:
201+
raise ValueError('Wrong flag value.')
202+
203+
190204
# 데이터셋 불러오는 코드 검증 (Shape: [batch, channel, height, width])
191205
def show_dataset(image: torch.Tensor, target: torch.Tensor):
192206
def make_plt_subplot(nrows: int, ncols: int, index: int, title: str, image):

0 commit comments

Comments
 (0)