Skip to content

Commit aadc454

Browse files
committed
Fix sampling using multiple filters
1 parent 5ee9835 commit aadc454

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

exllamav2/generator/sampler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ def prep_logit_filter(lf):
404404

405405
pt, et = f.get_next()
406406
if len(filters) > 1 and not isinstance(pt, set):
407-
pt, et = set(pt), set(et)
407+
if pt is not None: pt = set(pt)
408+
if et is not None: et = set(et)
408409

409410
if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
410411
if et is not None: end_tokens = et if end_tokens is None else end_tokens | et
@@ -425,7 +426,7 @@ def prep_logit_filter(lf):
425426
if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
426427
pass_tokens_list = [tokenizer.eos_token_id]
427428
logit_filter = prep_logit_filter(logit_filter)
428-
ext_c.logit_filter_exclusive(logit_filter, pass_tokens_list)
429+
ext_c.logit_filter_exclusive(logit_filter, [pass_tokens_list])
429430
else:
430431
logit_filter = prep_logit_filter(logit_filter)
431432
if isinstance(pass_tokens, set):

0 commit comments

Comments
 (0)