Skip to content

Commit 862a2da

Browse files
committed
Replace .long() by dtypes
1 parent 1c75f3b commit 862a2da

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ def forward(self, x, label):
191191

192192
def generate(self, label, shape=(8, 8), batch_size=64):
193193
param = next(self.parameters())
194-
x = torch.zeros(batch_size, *shape).long()
195-
x = x.to(param.device)
194+
x = torch.zeros((batch_size, *shape),
195+
dtype=torch.int64, device=param.device)
196196

197197
for i in range(shape[0]):
198198
for j in range(shape[1]):

pixelcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test():
115115

116116
def generate_samples():
117117
label = torch.arange(10).expand(10, 10).contiguous().view(-1)
118-
label = label.long().to(DEVICE)
118+
label = label.to(device=DEVICE, dtype=torch.int64)
119119

120120
latents = model.generate(label, shape=LATENT_SHAPE, batch_size=100)
121121
x_tilde, _ = autoencoder.decode(latents)

0 commit comments

Comments
 (0)