
Commit 1c75f3b

Use .to(device) instead of .cuda()
1 parent 60f0e03 commit 1c75f3b
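
Why this change: .cuda() unconditionally moves a tensor or module onto the default GPU and raises an error on CPU-only machines, whereas .to(device) takes the target device as a value, so a single constant switches the whole script between CPU and GPU. A minimal sketch of the pattern (the torch.cuda.is_available() fallback is a common convenience, not part of this commit, which hard-codes torch.device('cuda') and leaves the CPU alternative in a comment):

import torch

# Pick the GPU when present, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

t = torch.randn(4, 3)
t = t.to(DEVICE)   # no-op on CPU, host-to-device copy on GPU
# t = t.cuda()     # old style: fails outright without a CUDA device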

File tree

main.py
modules.py
pixelcnn.py

3 files changed: +18 −14 lines


main.py

Lines changed: 5 additions & 4 deletions

@@ -19,6 +19,7 @@
 LAMDA = 1
 LR = 2e-4
 
+DEVICE = torch.device('cuda') # torch.device('cpu')
 
 preproc_transform = transforms.Compose([
     transforms.ToTensor(),
@@ -39,15 +40,15 @@
     num_workers=NUM_WORKERS, pin_memory=True
 )
 
-model = AutoEncoder(INPUT_DIM, DIM, K).cuda()
+model = AutoEncoder(INPUT_DIM, DIM, K).to(DEVICE)
 opt = torch.optim.Adam(model.parameters(), lr=LR)
 
 
 def train():
     train_loss = []
     for batch_idx, (x, _) in enumerate(train_loader):
         start_time = time.time()
-        x = x.cuda()
+        x = x.to(DEVICE)
 
         opt.zero_grad()
 
@@ -85,7 +86,7 @@ def test():
     start_time = time.time()
     val_loss = []
     for batch_idx, (x, _) in enumerate(test_loader):
-        x = x.cuda()
+        x = x.to(DEVICE)
         x_tilde, z_e_x, z_q_x = model(x)
         loss_recons = F.mse_loss(x_tilde, x)
         loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
@@ -100,7 +101,7 @@ def test():
 
 def generate_samples():
     x, _ = test_loader.__iter__().next()
-    x = x[:32].cuda()
+    x = x[:32].to(DEVICE)
     x_tilde, _, _ = model(x)
 
     x_cat = torch.cat([x, x_tilde], 0)
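
One subtlety the main.py hunks rely on (a general PyTorch fact, not specific to this repo): Module.to() moves the module's parameters and buffers in place and returns the module itself, while Tensor.to() returns a new tensor, which is why the diff can chain AutoEncoder(...).to(DEVICE) directly but must reassign x = x.to(DEVICE). A small sketch:

import torch
import torch.nn as nn

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(10, 2).to(DEVICE)  # modules move in place; .to() returns self
x = torch.randn(8, 10)
x = x.to(DEVICE)                     # tensors are copied; reassignment is required
y = model(x)                         # inputs and parameters now share a device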

modules.py

Lines changed: 3 additions & 1 deletion

@@ -190,7 +190,9 @@ def forward(self, x, label):
         return self.output_conv(x_h)
 
     def generate(self, label, shape=(8, 8), batch_size=64):
-        x = torch.zeros(batch_size, *shape).long().cuda()
+        param = next(self.parameters())
+        x = torch.zeros(batch_size, *shape).long()
+        x = x.to(param.device)
 
         for i in range(shape[0]):
             for j in range(shape[1]):
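
The modules.py hunk handles the one place where no global DEVICE is visible: inside the model itself. Reading next(self.parameters()).device lets generate() allocate its sampling canvas on whatever device the module currently lives on, keeping the module self-contained. A hypothetical minimal module showing the same idiom (Sampler and sample_canvas are illustrative names, not from this repo):

import torch
import torch.nn as nn

class Sampler(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(512, 64)  # any parameter will do

    def sample_canvas(self, batch_size=64, shape=(8, 8)):
        # Infer the device from the module's own parameters, so this
        # works unchanged after sampler.to('cuda') or sampler.to('cpu').
        device = next(self.parameters()).device
        return torch.zeros(batch_size, *shape, dtype=torch.long, device=device)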

pixelcnn.py

Lines changed: 10 additions & 9 deletions

@@ -22,6 +22,7 @@
 K = 512
 LR = 3e-4
 
+DEVICE = torch.device('cuda') # torch.device('cpu')
 
 preproc_transform = transforms.Compose([
     transforms.ToTensor(),
@@ -42,23 +43,23 @@
     num_workers=NUM_WORKERS, pin_memory=True
 )
 
-autoencoder = AutoEncoder(INPUT_DIM, VAE_DIM, K).cuda()
+autoencoder = AutoEncoder(INPUT_DIM, VAE_DIM, K).to(DEVICE)
 autoencoder.load_state_dict(
     torch.load('models/{}_autoencoder.pt'.format(DATASET))
 )
 autoencoder.eval()
 
-model = GatedPixelCNN(K, DIM, N_LAYERS).cuda()
-criterion = nn.CrossEntropyLoss().cuda()
+model = GatedPixelCNN(K, DIM, N_LAYERS).to(DEVICE)
+criterion = nn.CrossEntropyLoss().to(DEVICE)
 opt = torch.optim.Adam(model.parameters(), lr=LR)
 
 
 def train():
     train_loss = []
     for batch_idx, (x, label) in enumerate(train_loader):
         start_time = time.time()
-        x = x.cuda()
-        label = label.cuda()
+        x = x.to(DEVICE)
+        label = label.to(DEVICE)
 
         # Get the latent codes for image x
         latents, _ = autoencoder.encode(x)
@@ -93,8 +94,8 @@ def test():
     val_loss = []
     with torch.no_grad():
         for batch_idx, (x, label) in enumerate(test_loader):
-            x = x.cuda()
-            label = label.cuda()
+            x = x.to(DEVICE)
+            label = label.to(DEVICE)
 
             latents, _ = autoencoder.encode(x)
             logits = model(latents.detach(), label)
@@ -114,7 +115,7 @@ def test():
 
 def generate_samples():
     label = torch.arange(10).expand(10, 10).contiguous().view(-1)
-    label = label.long().cuda()
+    label = label.long().to(DEVICE)
 
     latents = model.generate(label, shape=LATENT_SHAPE, batch_size=100)
     x_tilde, _ = autoencoder.decode(latents)
@@ -129,7 +130,7 @@ def generate_samples():
 
 def generate_reconstructions():
     x, _ = test_loader.__iter__().next()
-    x = x[:32].cuda()
+    x = x[:32].to(DEVICE)
 
     latents, _ = autoencoder.encode(x)
     x_tilde, _ = autoencoder.decode(latents)
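
One related spot the commit leaves untouched: the torch.load('models/{}_autoencoder.pt'.format(DATASET)) call passes no map_location, and by default torch.load restores tensors onto the device they were saved from, so a checkpoint written on a GPU machine would still fail to load when DEVICE is the CPU. A possible follow-up, not part of this commit (the path below is a stand-in for the formatted one in pixelcnn.py):

import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps every stored tensor to DEVICE at load time,
# making checkpoint loading as device-agnostic as the rest of the script.
state = torch.load('models/mnist_autoencoder.pt', map_location=DEVICE)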
