Training StyleGAN2 in Jax/Flax #7905
matthias-wright asked this question in Show and tell
🔥 @levskaya @andsteing @marcvanzee @jheek @gnecula @avital @mattjj
Hi all,
Here is an implementation of StyleGAN2 in Jax/Flax that is compatible with the official pretrained weights. The training code can be found here.
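The Highlights list jax.pmap for training across several devices. As a minimal sketch of that data-parallel pattern (not the author's training code), each device holds a copy of the parameters, receives its own shard of the batch, and the gradients are averaged with jax.lax.pmean so every copy stays identical. The toy linear-regression loss, the learning rate of 0.1, and the helper names replicate and shard are hypothetical stand-ins for the real GAN losses and optimizer.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy loss (linear regression) standing in for the GAN losses.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    # Each device computes the loss and gradients on its own shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients and loss over all devices participating in the pmap.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    # Plain SGD update; identical on every device because the grads are identical.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


def replicate(tree, n):
    # Hypothetical helper: add a leading device axis of size n to every leaf.
    return jax.tree_util.tree_map(lambda a: jnp.stack([a] * n), tree)


def shard(batch, n):
    # Hypothetical helper: reshape [B, ...] into [n, B // n, ...] for pmap.
    return jax.tree_util.tree_map(
        lambda a: a.reshape((n, -1) + a.shape[1:]), batch)
```

On a single CPU, jax.local_device_count() is 1 and the same code runs with a device axis of size 1; on a multi-GPU or TPU host the batch is split across all local devices without changing train_step.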
Highlights
- Multi-device training with jax.pmap
- tf.data.TFRecordDataset is used for efficient data loading

Lowlights
- Mixed precision training: I used flax.optim.DynamicScale for dynamic loss scaling, along with all the tricks from the original implementation (casting some operations to float32, pre-normalization in the modulated conv layer, float16 only at the higher resolutions, clipping the outputs of the convolution layers, etc.), but it does not work properly yet.

I also added instructions on how to train on your own dataset. If you have any questions, feel free to ask them here or on the GitHub page.
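To make the modulated conv tricks concrete, here is a minimal JAX sketch of a StyleGAN2 modulated convolution with optional pre-normalization and output clipping. It is modeled on the non-fused path of the official NVIDIA PyTorch code and is not taken from the author's Flax implementation; the function name, the NHWC/HWIO layouts, the eps of 1e-8, and clipping directly after the convolution (the official code clamps inside its bias/activation op) are assumptions.

```python
import jax
import jax.numpy as jnp


def modulated_conv2d(x, weight, styles, demodulate=True, pre_normalize=False,
                     clamp=None, eps=1e-8):
    """Sketch of a StyleGAN2-style modulated convolution (non-fused variant).

    x:       [N, H, W, C_in]       activations (NHWC, assumed layout)
    weight:  [kh, kw, C_in, C_out] shared kernel (HWIO, assumed layout)
    styles:  [N, C_in]             per-sample style scales
    """
    kh, kw, c_in, _ = weight.shape
    if pre_normalize and demodulate:
        # Shrink the weights (per output channel) and the styles (per sample)
        # so their largest entries are small; this keeps float16 from
        # overflowing, and demodulation divides these factors back out.
        weight = weight * (
            1.0 / jnp.sqrt(c_in * kh * kw)
            / jnp.max(jnp.abs(weight), axis=(0, 1, 2), keepdims=True))
        styles = styles / jnp.max(jnp.abs(styles), axis=1, keepdims=True)

    if demodulate:
        # d[n, o] = 1 / sqrt(sum_{kh, kw, i} (w[kh, kw, i, o] * s[n, i])^2 + eps)
        w_sq = jnp.sum(jnp.square(weight), axis=(0, 1))          # [C_in, C_out]
        dcoefs = jax.lax.rsqrt(jnp.square(styles) @ w_sq + eps)   # [N, C_out]

    # Modulate the input channels instead of building per-sample kernels.
    x = x * styles[:, None, None, :].astype(x.dtype)
    y = jax.lax.conv_general_dilated(
        x, weight.astype(x.dtype),
        window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
    if demodulate:
        # Demodulate the output channels per sample.
        y = y * dcoefs[:, None, None, :].astype(y.dtype)
    if clamp is not None:
        # Clip the convolution output to a fixed range.
        y = jnp.clip(y, -clamp, clamp)
    return y
```

Because the demodulation coefficients absorb any per-sample rescaling of the styles and any per-output-channel rescaling of the weights, pre-normalization changes the result only through the eps term. That is why the official code can apply it on the float16 path alone.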
Cheers!
Some results

Style mixing
