Skip to content

Commit f52e199

Browse files
author
morvanzhou
committed
update for new version of torch
1 parent 50d201c commit f52e199

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.idea

tutorial-contents/406_GAN.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,22 @@ def artist_works(): # painting from the famous artist (real target)
5555
plt.ion() # something about continuous plotting
5656

5757
for step in range(10000):
58-
artist_paintings = artist_works() # real painting from artist
59-
G_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideas
58+
artist_paintings = artist_works() # real painting from artist
59+
G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True) # random ideas\n
6060
G_paintings = G(G_ideas) # fake painting from G (random ideas)
61-
62-
prob_artist0 = D(artist_paintings) # D try to increase this prob
6361
prob_artist1 = D(G_paintings) # D try to reduce this prob
64-
62+
G_loss = torch.mean(torch.log(1. - prob_artist1))
63+
opt_G.zero_grad()
64+
G_loss.backward()
65+
opt_G.step()
66+
67+
prob_artist0 = D(artist_paintings) # D try to increase this prob
68+
prob_artist1 = D(G_paintings.detach()) # D try to reduce this prob
6569
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
66-
G_loss = torch.mean(torch.log(1. - prob_artist1))
67-
6870
opt_D.zero_grad()
6971
D_loss.backward(retain_graph=True) # reusing computational graph
7072
opt_D.step()
7173

72-
opt_G.zero_grad()
73-
G_loss.backward()
74-
opt_G.step()
75-
7674
if step % 50 == 0: # plotting
7775
plt.cla()
7876
plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)

0 commit comments

Comments
 (0)