Skip to content

Commit 81f8a38

Browse files
tenderness-gitsir1st-inc
authored andcommitted
bug fix
1 parent c833b6e commit 81f8a38

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

diffsynth_engine/pipelines/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,11 @@ def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
151151
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
152152
return noise
153153

154-
def encode_image(self, image: torch.Tensor) -> torch.Tensor:
154+
def encode_image(
155+
self, image: torch.Tensor, tiled: bool = False, tile_size: int = 64, tile_stride: int = 32
156+
) -> torch.Tensor:
155157
image = image.to(self.device, self.vae_encoder.dtype)
156-
latents = self.vae_encoder(
157-
image, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
158-
)
158+
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
159159
return latents
160160

161161
def decode_image(self, latent: torch.Tensor) -> torch.Tensor:
@@ -187,7 +187,7 @@ def prepare_latents(
187187
self.load_models_to_device(["vae_encoder"])
188188
noise = latents
189189
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.dtype)
190-
latents = self.encode_image(image, tiled, tile_size, tile_stride)
190+
latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
191191
init_latents = latents.clone()
192192
latents = self.sampler.add_noise(latents, noise, sigma_start)
193193
else:

0 commit comments

Comments
 (0)