Skip to content

Commit 0afb904

Browse files
committed
adding torch plot sample
1 parent b5b6895 commit 0afb904

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

mitdeeplearning/util.py

+32-7
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,46 @@ def display_model(model):
1212
return ipythondisplay.Image("tmp.png")
1313

1414

15-
def plot_sample(x, y, vae):
15+
def plot_sample(x, y, vae, backend='tf'):
16+
"""Plot original and reconstructed images side by side.
17+
18+
Args:
19+
x: Input images array of shape [B, H, W, C] (TF) or [B, C, H, W] (PT)
20+
y: Labels array of shape [B] where 1 indicates a face
21+
vae: VAE model (TensorFlow or PyTorch)
22+
framework: 'tf' or 'pt' indicating which framework to use
23+
"""
1624
plt.figure(figsize=(2, 1))
17-
plt.subplot(1, 2, 1)
1825

19-
idx = np.where(y == 1)[0][0]
26+
if backend == 'tf':
27+
idx = np.where(y == 1)[0][0]
28+
_, _, _, recon = vae(x)
29+
recon = np.clip(recon, 0, 1)
30+
31+
elif backend == 'pt':
32+
y = y.detach().cpu().numpy()
33+
face_indices = np.where(y == 1)[0]
34+
idx = face_indices[0] if len(face_indices) > 0 else 0
35+
36+
with torch.inference_mode():
37+
_, _, _, recon = vae(x)
38+
recon = torch.clamp(recon, 0, 1)
39+
recon = recon.permute(0, 2, 3, 1).detach().cpu().numpy()
40+
x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
41+
42+
else:
43+
raise ValueError("framework must be 'tf' or 'pt'")
44+
45+
plt.subplot(1, 2, 1)
2046
plt.imshow(x[idx])
2147
plt.grid(False)
2248

23-
plt.subplot(1, 2, 2)
24-
_, _, _, recon = vae(x)
25-
recon = np.clip(recon, 0, 1)
49+
plt.subplot(1, 2, 2)
2650
plt.imshow(recon[idx])
2751
plt.grid(False)
2852

29-
# plt.show()
53+
if backend == 'pt':
54+
plt.show()
3055

3156

3257
class LossHistory:

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ def get_dist(pkgname):
2222
setup(
2323
name = 'mitdeeplearning', # How you named your package folder (MyLib)
2424
packages = ['mitdeeplearning'], # Chose the same as "name"
25-
version = '0.6.1', # Start with a small number and increase it with every change you make
25+
version = '0.7.2', # Start with a small number and increase it with every change you make
2626
license='MIT', # Chose a license from here: https://help.github.com/articles/licensing-a-repository
2727
description = 'Official software labs for MIT Introduction to Deep Learning (http://introtodeeplearning.com)', # Give a short description about your library
2828
author = 'Alexander Amini', # Type in your name
2929
author_email = 'introtodeeplearning-staff@mit.edu', # Type in your E-Mail
3030
url = 'http://introtodeeplearning.com', # Provide either the link to your github or to your website
31-
download_url = 'https://github.com/aamini/introtodeeplearning/archive/v0.6.1.tar.gz', # I explain this later on
31+
download_url = 'https://github.com/aamini/introtodeeplearning/archive/v0.7.2.tar.gz', # I explain this later on
3232
keywords = ['deep learning', 'neural networks', 'tensorflow', 'introduction'], # Keywords that define your package best
3333
install_requires=install_deps,
3434
classifiers=[

0 commit comments

Comments
 (0)