@@ -151,11 +151,11 @@ def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
151
151
noise = torch .randn (shape , generator = generator , device = device , dtype = dtype )
152
152
return noise
153
153
154
- def encode_image (self , image : torch .Tensor ) -> torch .Tensor :
154
+ def encode_image (
155
+ self , image : torch .Tensor , tiled : bool = False , tile_size : int = 64 , tile_stride : int = 32
156
+ ) -> torch .Tensor :
155
157
image = image .to (self .device , self .vae_encoder .dtype )
156
- latents = self .vae_encoder (
157
- image , tiled = self .vae_tiled , tile_size = self .vae_tile_size , tile_stride = self .vae_tile_stride
158
- )
158
+ latents = self .vae_encoder (image , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )
159
159
return latents
160
160
161
161
def decode_image (self , latent : torch .Tensor ) -> torch .Tensor :
@@ -187,7 +187,7 @@ def prepare_latents(
187
187
self .load_models_to_device (["vae_encoder" ])
188
188
noise = latents
189
189
image = self .preprocess_image (input_image ).to (device = self .device , dtype = self .dtype )
190
- latents = self .encode_image (image , tiled , tile_size , tile_stride )
190
+ latents = self .encode_image (image , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )
191
191
init_latents = latents .clone ()
192
192
latents = self .sampler .add_noise (latents , noise , sigma_start )
193
193
else :
0 commit comments