It seems this code only computes the FID for CIFAR-10, CelebA, and 'dots'. I am wondering how to compute the FID for MNIST.
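For reference, here is a minimal sketch of what I assume the MNIST case needs. MNIST images are single-channel, so they would first have to be repeated to 3 channels (and, for an Inception-v3 extractor, resized, which is not shown). Then the standard Fréchet distance, ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}), is computed between the feature statistics of real and generated samples. The helper names `to_three_channels` and `frechet_distance` are hypothetical and not part of this repo. The features passed in are assumed to come from whatever feature extractor the repo already uses for CIFAR-10/CelebA.

```python
import numpy as np


def to_three_channels(images):
    """Hypothetical helper: repeat grayscale MNIST images (N, 28, 28) into
    (N, 3, 28, 28) so they match a 3-channel feature extractor's input.
    Resizing (e.g. to 299x299 for Inception-v3) is a separate step not shown."""
    images = np.asarray(images, dtype=np.float32)
    return np.repeat(images[:, None, :, :], 3, axis=1)


def _sqrtm_psd(mat):
    # Square root of a symmetric PSD matrix via eigendecomposition;
    # tiny negative eigenvalues from round-off are clipped to zero.
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(feats_real, feats_fake):
    """Hypothetical helper: FID between two feature sets of shape (N, D)."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_fake, rowvar=False)
    # Tr((S1 S2)^{1/2}) equals Tr((R S2 R)^{1/2}) with R = S1^{1/2}, because the
    # two products share eigenvalues; R S2 R is symmetric, so eigvalsh applies.
    root1 = _sqrtm_psd(s1)
    middle = root1 @ s2 @ root1
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * tr_covmean)
```

One caveat: if the features come from a network other than the Inception-v3 used in the original FID definition (e.g. an MNIST classifier), the resulting numbers are only comparable with other scores computed using that same network.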