
Commit e3875a5

Add device parameter to overridden _build_causal_attention_mask (#184)
1 parent: ea904e3

File tree

1 file changed: 2 additions & 2 deletions

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 2 additions & 2 deletions
@@ -283,8 +283,8 @@ def convert_text_encoder(pipe, args):
     }
     logger.info(f"Sample inputs spec: {sample_text_encoder_inputs_spec}")

-    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
-        mask = torch.ones((bsz, seq_len, seq_len), dtype=dtype) * -1e4
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
+        mask = torch.ones((bsz, seq_len, seq_len), dtype=dtype, device=device) * -1e4
         mask.triu_(1)
         mask = mask.unsqueeze(1)
         return mask
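
For context: the conversion path replaces CLIP's causal-mask builder with this override, and newer Hugging Face transformers releases presumably call _build_causal_attention_mask with an extra device argument, so the patched function must accept (and forward) it. Below is a minimal, self-contained sketch of the updated function together with a hypothetical way to attach it to the pipeline's text model; the attachment via types.MethodType on pipe.text_encoder.text_model is an illustrative assumption, not necessarily how torch2coreml.py wires the override in.

import torch

def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
    # Additive causal mask: a large negative value (-1e4, as in the diff above)
    # added to the attention logits effectively blocks future positions.
    mask = torch.ones((bsz, seq_len, seq_len), dtype=dtype, device=device) * -1e4
    mask.triu_(1)             # keep only entries strictly above the diagonal
    mask = mask.unsqueeze(1)  # broadcastable head dim: (bsz, 1, seq_len, seq_len)
    return mask

# Hypothetical attachment (illustration only): bind the override onto the
# pipeline's CLIP text model before tracing/conversion, e.g.
#   import types
#   pipe.text_encoder.text_model._build_causal_attention_mask = types.MethodType(
#       _build_causal_attention_mask, pipe.text_encoder.text_model)

# Standalone check that the new signature tolerates a device argument:
mask = _build_causal_attention_mask(None, bsz=1, seq_len=4,
                                    dtype=torch.float32,
                                    device=torch.device("cpu"))
print(mask.shape)  # torch.Size([1, 1, 4, 4])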
