Skip to content

Commit b412493

Browse files
authored
Merge pull request #2 from ritheshkumar95/tristan/gitignore
Add gitignore & Pytorch 0.4 syntax
2 parents 3874acb + 862a2da commit b412493

File tree

4 files changed

+130
-14
lines changed

4 files changed

+130
-14
lines changed

.gitignore

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
env/
12+
build/
13+
develop-eggs/
14+
dist/
15+
downloads/
16+
eggs/
17+
.eggs/
18+
lib/
19+
lib64/
20+
parts/
21+
sdist/
22+
var/
23+
wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
49+
# Translations
50+
*.mo
51+
*.pot
52+
53+
# Django stuff:
54+
*.log
55+
local_settings.py
56+
57+
# Flask stuff:
58+
instance/
59+
.webassets-cache
60+
61+
# Scrapy stuff:
62+
.scrapy
63+
64+
# Sphinx documentation
65+
docs/_build/
66+
67+
# PyBuilder
68+
target/
69+
70+
# Jupyter Notebook
71+
.ipynb_checkpoints
72+
73+
# pyenv
74+
.python-version
75+
76+
# celery beat schedule file
77+
celerybeat-schedule
78+
79+
# SageMath parsed files
80+
*.sage.py
81+
82+
# dotenv
83+
.env
84+
85+
# virtualenv
86+
.venv
87+
venv/
88+
ENV/
89+
90+
# Spyder project settings
91+
.spyderproject
92+
.spyproject
93+
94+
# Rope project settings
95+
.ropeproject
96+
97+
# mkdocs documentation
98+
/site
99+
100+
# mypy
101+
.mypy_cache/
102+
103+
# Temporary
104+
tmp/
105+
run.sh
106+
107+
# Logs & Saves
108+
logs/
109+
saves/
110+
111+
# Slurm
112+
*.out

main.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
LAMDA = 1
2020
LR = 3e-4
2121

22+
DEVICE = torch.device('cuda') # torch.device('cpu')
2223

2324
preproc_transform = transforms.Compose([
2425
transforms.ToTensor(),
@@ -39,15 +40,15 @@
3940
num_workers=NUM_WORKERS, pin_memory=True
4041
)
4142

42-
model = AutoEncoder(INPUT_DIM, DIM, K).cuda()
43+
model = AutoEncoder(INPUT_DIM, DIM, K).to(DEVICE)
4344
opt = torch.optim.Adam(model.parameters(), lr=LR)
4445

4546

4647
def train():
4748
train_loss = []
4849
for batch_idx, (x, _) in enumerate(train_loader):
4950
start_time = time.time()
50-
x = x.cuda()
51+
x = x.to(DEVICE)
5152

5253
opt.zero_grad()
5354

@@ -85,7 +86,7 @@ def test():
8586
start_time = time.time()
8687
val_loss = []
8788
for batch_idx, (x, _) in enumerate(test_loader):
88-
x = x.cuda()
89+
x = x.to(DEVICE)
8990
x_tilde, z_e_x, z_q_x = model(x)
9091
loss_recons = F.mse_loss(x_tilde, x)
9192
loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
@@ -100,7 +101,7 @@ def test():
100101

101102
def generate_samples():
102103
x, _ = test_loader.__iter__().next()
103-
x = x[:32].cuda()
104+
x = x[:32].to(DEVICE)
104105
x_tilde, _, _ = model(x)
105106

106107
x_cat = torch.cat([x, x_tilde], 0)

modules.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ def forward(self, x, label):
190190
return self.output_conv(x_h)
191191

192192
def generate(self, label, shape=(8, 8), batch_size=64):
193-
x = torch.zeros(batch_size, *shape).long().cuda()
193+
param = next(self.parameters())
194+
x = torch.zeros((batch_size, *shape),
195+
dtype=torch.int64, device=param.device)
194196

195197
for i in range(shape[0]):
196198
for j in range(shape[1]):

pixelcnn.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
K = 512
2323
LR = 3e-4
2424

25+
DEVICE = torch.device('cuda') # torch.device('cpu')
2526

2627
preproc_transform = transforms.Compose([
2728
transforms.ToTensor(),
@@ -42,23 +43,23 @@
4243
num_workers=NUM_WORKERS, pin_memory=True
4344
)
4445

45-
autoencoder = AutoEncoder(INPUT_DIM, VAE_DIM, K).cuda()
46+
autoencoder = AutoEncoder(INPUT_DIM, VAE_DIM, K).to(DEVICE)
4647
autoencoder.load_state_dict(
4748
torch.load('models/{}_autoencoder.pt'.format(DATASET))
4849
)
4950
autoencoder.eval()
5051

51-
model = GatedPixelCNN(K, DIM, N_LAYERS).cuda()
52-
criterion = nn.CrossEntropyLoss().cuda()
52+
model = GatedPixelCNN(K, DIM, N_LAYERS).to(DEVICE)
53+
criterion = nn.CrossEntropyLoss().to(DEVICE)
5354
opt = torch.optim.Adam(model.parameters(), lr=LR)
5455

5556

5657
def train():
5758
train_loss = []
5859
for batch_idx, (x, label) in enumerate(train_loader):
5960
start_time = time.time()
60-
x = x.cuda()
61-
label = label.cuda()
61+
x = x.to(DEVICE)
62+
label = label.to(DEVICE)
6263

6364
# Get the latent codes for image x
6465
latents, _ = autoencoder.encode(x)
@@ -93,8 +94,8 @@ def test():
9394
val_loss = []
9495
with torch.no_grad():
9596
for batch_idx, (x, label) in enumerate(test_loader):
96-
x = x.cuda()
97-
label = label.cuda()
97+
x = x.to(DEVICE)
98+
label = label.to(DEVICE)
9899

99100
latents, _ = autoencoder.encode(x)
100101
logits = model(latents.detach(), label)
@@ -114,7 +115,7 @@ def test():
114115

115116
def generate_samples():
116117
label = torch.arange(10).expand(10, 10).contiguous().view(-1)
117-
label = label.long().cuda()
118+
label = label.to(device=DEVICE, dtype=torch.int64)
118119

119120
latents = model.generate(label, shape=LATENT_SHAPE, batch_size=100)
120121
x_tilde, _ = autoencoder.decode(latents)
@@ -129,7 +130,7 @@ def generate_samples():
129130

130131
def generate_reconstructions():
131132
x, _ = test_loader.__iter__().next()
132-
x = x[:32].cuda()
133+
x = x[:32].to(DEVICE)
133134

134135
latents, _ = autoencoder.encode(x)
135136
x_tilde, _ = autoencoder.decode(latents)

0 commit comments

Comments
 (0)