Description
Hey guys, I'm having an issue with the deprecated BetterTransformer: my models have slowed down 2-3x on CPU.
System Info
I am using Python 3.11.
Upgraded library versions:
- torch 1.13 -> 2.7
- transformers 4.24 -> 4.48
Information
I was using code like this:

```python
BetterTransformer.transform(self.model.bert, keep_original_model=False)
```

with an absolutely basic fine-tuned BERT on CPU.
Now I'm trying this instead (added `attn_implementation='sdpa'`):

```python
config = AutoConfig.from_pretrained(bert_config_path, local_files_only=True, attn_implementation='sdpa')
self.bert = BertModel(config=config)
```

And inference slowed from 2-3 s to around 4-6 s.
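For context on what I expected: as far as I understand, `attn_implementation='sdpa'` just routes attention through `torch.nn.functional.scaled_dot_product_attention`, which should match the eager softmax math while using a fused kernel. A minimal sketch I used to sanity-check that on CPU (toy head sizes, not my real model):

```python
import torch
import torch.nn.functional as F

# Toy shapes standing in for one BERT attention call (hypothetical sizes):
# batch=1, heads=12, seq_len=128, head_dim=64.
q = torch.randn(1, 12, 128, 64)
k = torch.randn(1, 12, 128, 64)
v = torch.randn(1, 12, 128, 64)

# Fused path used when attn_implementation='sdpa'.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# Eager reference: softmax(Q K^T / sqrt(d)) V.
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
out_eager = scores.softmax(dim=-1) @ v

# The two paths agree numerically on CPU.
print(torch.allclose(out_sdpa, out_eager, atol=1e-5))
```

So the outputs match; the question is purely about the speed of the fused path on CPU.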
There are 3 BERT-based models in my service, and each of them contributes to the slowdown.
One of the models (NER) is initialized like this:

```python
cls._instance = pipeline(
    'token-classification',
    model=model_path,
    aggregation_strategy='average',
    device=device,
    accelerator='bettertransformer',
)
```
I understand that the deprecation is inevitable, but maybe you can give me some advice on what I'm missing to get the speed back?
Right now I'm using only `attn_implementation='sdpa'`, which according to the docs is the default anyway, and it didn't give any speedup :(
Maybe I should do something else? Flash-attention backends aren't an option on CPU, as far as I know.