This repository provides direct and simple access to the pretrained "deep" versions of BigGAN for 128, 256 and 512 pixel resolutions, as described in the [associated publication](https://openreview.net/forum?id=B1xsqj09Fm).
Here are some details on the models:

- `BigGAN-deep-128`: a 50.4M-parameter model generating 128x128 pixel images (model dump: 201 MB),
- `BigGAN-deep-256`: a 55.9M-parameter model generating 256x256 pixel images (model dump: 224 MB),
- `BigGAN-deep-512`: a 56.2M-parameter model generating 512x512 pixel images (model dump: 225 MB).
Please refer to Appendix B of the paper for details on the architectures.
All models comprise pre-computed batch norm statistics for 51 truncation values between 0 and 1 (see Appendix C.1 in the paper for details).
## Usage
Here is a quick-start example using `BigGAN` with a pre-trained model.
See the [doc section](#doc) below for details on these classes and methods.

```python
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_name, truncated_noise_sample,
                                       save_as_images, display_in_terminal)

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
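import logging
logging.basicConfig(level=logging.INFO)

# The steps below are a minimal sketch of a generation pipeline; the shortcut name
# 'biggan-deep-256', the class name and the variable name `dogball` are illustrative.

# Load a pre-trained 256x256 model
model = BigGAN.from_pretrained('biggan-deep-256')

# Prepare the inputs: a one-hot class vector and a truncated noise vector
truncation = 0.4
class_vector = one_hot_from_name('golden retriever', batch_size=1)
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1)

# Convert the numpy arrays to torch tensors
class_vector = torch.from_numpy(class_vector)
noise_vector = torch.from_numpy(noise_vector)

# Generate an image: the output has shape [batch_size, 3, 256, 256]
with torch.no_grad():
    dogball = model(noise_vector, class_vector, truncation)

# Save the output as PNG images (output_0.png, output_1.png, ...)
save_as_images(dogball)
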
# If you have a sixel-compatible terminal you can display the images in the terminal
# (see https://github.com/saitoha/libsixel for details)
display_in_terminal(dogball)
```
## Doc
### Loading DeepMind's pre-trained weights
To load one of DeepMind's pre-trained models, instantiate a `BigGAN` model with `from_pretrained()` as:
```python
model = BigGAN.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
```

where `PRE_TRAINED_MODEL_NAME_OR_PATH` is either the shortcut name of one of the pre-trained models listed above or a path to a local folder containing a model dump, and `cache_dir` is an optional path to a specific directory in which to download and cache the pre-trained weights.
### Configuration
`BigGANConfig` is a class to store and load BigGAN configurations. It's defined in [`config.py`](./pytorch_pretrained_biggan/config.py).
Here are some details on the attributes:
- `output_dim`: output resolution of the GAN (128, 256 or 512) for the pre-trained models,
- `z_dim`: size of the noise vector (128 for the pre-trained models).
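
As a quick illustration, here is a minimal sketch for inspecting these attributes on a loaded model (it assumes `'biggan-deep-256'` is a valid shortcut name and that the model keeps its configuration in `model.config`):

```python
from pytorch_pretrained_biggan import BigGAN

model = BigGAN.from_pretrained('biggan-deep-256')  # assumed shortcut name
config = model.config                              # assumed attribute holding the BigGANConfig

print(config.output_dim)  # e.g. 256 for the 256x256 model
print(config.z_dim)       # 128 for the pre-trained models
```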
### Model
`BigGAN` is a PyTorch model (`torch.nn.Module`) of BigGAN defined in [`model.py`](./pytorch_pretrained_biggan/model.py). This model comprises the class embeddings (a linear layer) and the generator with a series of convolutions and conditional batch norms. The discriminator is currently not implemented since pre-trained weights have not been released for it.
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
- `z`: a torch.FloatTensor of shape [batch_size, config.z_dim] with noise sampled from a truncated normal distribution, and
- `class_label`: a torch.FloatTensor of shape [batch_size, 1000] with the ImageNet class labels encoded as one-hot vectors (as returned by the `one_hot_from_int` and `one_hot_from_name` utilities below), and
- `truncation`: a float between 0 (excluded) and 1, the truncation value of the truncated normal distribution used to create the noise vector. This truncation value is also used to select among the pre-computed batch norm statistics (means and variances).

`BigGAN` *outputs* a tensor of shape [batch_size, 3, resolution, resolution] where resolution is 128, 256 or 512 depending on the model.
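
To make these inputs and outputs concrete, here is a minimal sketch of a forward pass (the batch size, the class indices and the `'biggan-deep-256'` shortcut name are illustrative assumptions):

```python
import torch
from pytorch_pretrained_biggan import BigGAN, one_hot_from_int, truncated_noise_sample

truncation = 0.4
model = BigGAN.from_pretrained('biggan-deep-256')

# z: truncated noise, class_label: one-hot vectors for two ImageNet class indices
noise = torch.from_numpy(truncated_noise_sample(batch_size=2, truncation=truncation))
labels = torch.from_numpy(one_hot_from_int([207, 281], batch_size=2))

with torch.no_grad():
    output = model(noise, labels, truncation)

print(output.shape)  # torch.Size([2, 3, 256, 256])
```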
### Utilities: Images, Noise, Imagenet classes

We provide a few utility methods to use the model. They are defined in [`utils.py`](./pytorch_pretrained_biggan/utils.py).
Here are some details on these methods:

- `truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None)`:

    Create a truncated noise vector.
    - Params:
        - batch_size: batch size.
        - dim_z: dimension of the noise vector (128 for the pre-trained models).
        - truncation: truncation value of the truncated normal distribution.
        - seed: optional seed for the random number generator.
    - Output:
        - array of shape (batch_size, dim_z)

- `convert_to_images(obj)`:

    Convert an output tensor from BigGAN into a list of images.
    - Params:
        - obj: tensor or numpy array of shape (batch_size, channels, height, width)
    - Output:
        - list of Pillow Images of size (height, width)

- `save_as_images(obj, file_name='output')`:

    Convert an output tensor from BigGAN into images and save them to disk.
    - Params:
        - obj: tensor or numpy array of shape (batch_size, channels, height, width)
        - file_name: path and beginning of the filename to save to.
          Images will be saved as `file_name_{image_number}.png`

- `display_in_terminal(obj)`:

    Convert an output tensor from BigGAN and display the resulting images in the terminal. This function uses `libsixel` and will only work in a libsixel-compatible terminal. Please refer to https://github.com/saitoha/libsixel for more details.
    - Params:
        - obj: tensor or numpy array of shape (batch_size, channels, height, width)

- `one_hot_from_int(int_or_list, batch_size=1)`:

    Create a one-hot vector from a class index or a list of class indices.
    - Params:
        - int_or_list: int, or list of ints, of the ImageNet classes (between 0 and 999)
        - batch_size: batch size.
    - Output:
        - array of shape (batch_size, 1000)

- `one_hot_from_name(class_name, batch_size=1)`:

    Create a one-hot vector from the name of an ImageNet class ('tennis ball', 'daisy', ...). We use NLTK's WordNet search to find the relevant ImageNet synset and take the first match. If we can't find it directly, we look at the hyponyms and hypernyms of the class name.
    - Params:
        - class_name: string containing the name of an ImageNet object.
    - Output:
        - array of shape (batch_size, 1000)
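
To tie these utilities together, here is a short sketch (the random array standing in for a real BigGAN output and the file name `'demo'` are illustrative assumptions):

```python
import numpy as np
from pytorch_pretrained_biggan import convert_to_images, save_as_images, one_hot_from_name

# One-hot class vector from a class name (the lookup relies on NLTK's WordNet)
class_vector = one_hot_from_name('daisy', batch_size=1)
print(class_vector.shape)  # (1, 1000)

# For illustration only: a random array with the layout of a BigGAN output,
# i.e. shape (batch_size, channels, height, width) with values in [-1, 1]
fake_output = np.random.uniform(-1, 1, size=(2, 3, 128, 128)).astype('float32')

images = convert_to_images(fake_output)        # list of 2 Pillow Images of size (128, 128)
save_as_images(fake_output, file_name='demo')  # saves demo_0.png and demo_1.png
```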
## Download and conversion scripts
Scripts to download and convert the TensorFlow models from TensorFlow Hub are provided in [./scripts](./scripts/).