Add support for lite-whisper #1886
Draft pull request: wants to merge 63 commits into master.

Commits (63)
All commits are by Eyoel-gebre.

ba51314  testing spec (Apr 9, 2025)
63cb8ff  testing spec (Apr 10, 2025)
24fe394  added test (Apr 10, 2025)
d258577  update (Apr 10, 2025)
607a649  fix (Apr 10, 2025)
cf840ad  loading fixes (Apr 11, 2025)
0412e6f  fix (Apr 14, 2025)
7dbd2d8  fixed (Apr 14, 2025)
a5ac7cd  fixed (Apr 14, 2025)
a667e1e  test (Apr 15, 2025)
9220303  test (Apr 15, 2025)
80e6b7b  test (Apr 15, 2025)
e501642  test (Apr 15, 2025)
e17f23f  test (Apr 15, 2025)
eb83092  test (Apr 15, 2025)
f786c6d  test_axis_1 (Apr 17, 2025)
e7cfdad  test_axis_1 (Apr 17, 2025)
a6a4c42  test_axis_1 (Apr 17, 2025)
ff9a3af  test_axis_1 (Apr 17, 2025)
b8d0c7a  test_axis_1 (Apr 17, 2025)
566e35b  test_axis_1 (Apr 17, 2025)
5230ee4  test_axis_1 (Apr 17, 2025)
220ae07  test_axis_1 (Apr 17, 2025)
5259c69  test_axis_1 (Apr 17, 2025)
c016c41  test_axis_1 (Apr 17, 2025)
bfd3386  test_axis_1 (Apr 17, 2025)
58d2b5a  logging (Apr 21, 2025)
2b0fbf1  logging (Apr 21, 2025)
4aa77d8  logging (Apr 21, 2025)
84f54d4  logging (Apr 21, 2025)
e08a06e  logging (Apr 21, 2025)
ab31afc  logging (Apr 21, 2025)
d6b924b  logging (Apr 21, 2025)
58c7547  logging (Apr 21, 2025)
94e6b34  logging (Apr 21, 2025)
97aeeeb  logging (Apr 21, 2025)
b113821  logging (Apr 21, 2025)
77837b9  logging (Apr 21, 2025)
98e8941  logging (Apr 21, 2025)
b6c8edc  logging (Apr 21, 2025)
c1c5c47  spec (Apr 21, 2025)
99ab963  remove-test (Apr 22, 2025)
367b6c6  remove-garbage (Apr 22, 2025)
d7dc107  remove-garbage (Apr 22, 2025)
93e4a70  updated model executor for lite-whisper (Apr 30, 2025)
a29eef3  debugging (Apr 30, 2025)
0b5fa40  debugging (Apr 30, 2025)
7266663  debugging (Apr 30, 2025)
dacb949  debugging (Apr 30, 2025)
ee0527c  debugging (Apr 30, 2025)
5cf90b9  debugging (Apr 30, 2025)
3098c95  naming fix (May 1, 2025)
5caaee5  . (May 2, 2025)
23df460  shapes (May 2, 2025)
6bdce85  shapes2 (May 2, 2025)
2821169  shape3 (May 2, 2025)
b4b35b3  preprocessor (May 3, 2025)
f2b7b22  small (May 3, 2025)
af19fb4  small (May 3, 2025)
84e346a  dims (May 4, 2025)
6ac5325  tweaks (May 4, 2025)
e6a83fc  error handling (May 4, 2025)
76c0bb4  minor (May 4, 2025)
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
*.pyc
/.vs
.vscode
/build

CMake*.json
1 change: 1 addition & 0 deletions include/ctranslate2/layers/attention_layer.h
@@ -52,6 +52,7 @@ namespace ctranslate2 {
const bool multi_query = false);

protected:
bool _is_low_rank;
const bool _tensor_parallel;
const dim_t _num_heads;
const bool _self_attention;
2 changes: 2 additions & 0 deletions include/ctranslate2/layers/common.h
@@ -135,7 +135,9 @@ namespace ctranslate2 {
void select_weights(const StorageView* index, const StorageView* extra_bias = nullptr);
private:
bool _packed_weight;
bool _is_low_rank;
const StorageView& _weight;
const StorageView* _weight2;
const StorageView* _bias;
const StorageView* _qscale;
const StorageView* _qzero;
98 changes: 95 additions & 3 deletions python/ctranslate2/converters/transformers.py
@@ -3,6 +3,7 @@
import gc
import itertools
import os
import re

from typing import List, Optional

@@ -96,6 +97,13 @@ def __init__(
trust_remote_code: Allow converting models using custom code.
"""
self._model_name_or_path = model_name_or_path
self._model_processor_name = model_name_or_path
if model_name_or_path.startswith('efficient-speech/lite-whisper'):
[Review comment from a Contributor]: I don't think this is the best approach, since a different org might upload their own model and then this condition will evaluate to false.
# If this is a lite-whisper model, use openai's
# corresponding preprocessor.
regex = r'whisper-[a-z0-9-]+?(?=-(?:fast|acc)|$)'
regex_result = re.search(regex, model_name_or_path)
self._model_processor_name = f"openai/{regex_result.group()}"
self._activation_scales = activation_scales
self._copy_files = copy_files
self._load_as_float16 = load_as_float16
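To make that mapping concrete, here is a standalone sketch of how the regex resolves a lite-whisper repository name to the matching OpenAI preprocessor (the model ids below are illustrative examples, not an exhaustive list):

```python
import re

# Pattern copied from the converter above; model ids are illustrative.
regex = r"whisper-[a-z0-9-]+?(?=-(?:fast|acc)|$)"

for name in [
    "efficient-speech/lite-whisper-large-v3-turbo-fast",
    "efficient-speech/lite-whisper-large-v3-acc",
    "efficient-speech/lite-whisper-large-v3",
]:
    match = re.search(regex, name)
    print(name, "->", f"openai/{match.group()}")

# efficient-speech/lite-whisper-large-v3-turbo-fast -> openai/whisper-large-v3-turbo
# efficient-speech/lite-whisper-large-v3-acc        -> openai/whisper-large-v3
# efficient-speech/lite-whisper-large-v3            -> openai/whisper-large-v3
```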
@@ -119,7 +127,6 @@ def _load(self):
% (config_name, ", ".join(sorted(_MODEL_LOADERS.keys())))
)

model_class = getattr(transformers, loader.architecture_name)
tokenizer_class = transformers.AutoTokenizer

kwargs = {
@@ -137,14 +144,21 @@
if self._trust_remote_code:
kwargs["trust_remote_code"] = self._trust_remote_code

model = self.load_model(model_class, self._model_name_or_path, **kwargs)
if hasattr(transformers, loader.architecture_name):
model_class = getattr(transformers, loader.architecture_name)
model = self.load_model(model_class, self._model_name_or_path, **kwargs)
elif self._model_name_or_path.startswith('efficient-speech/lite-whisper'):
[Review comment from a Contributor]: same as above
model = transformers.AutoModel.from_pretrained(self._model_name_or_path, **kwargs)
else:
raise ValueError(
"The model %s is not supported by the converter. " % self._model_name_or_path)

tokenizer_kwargs = {}
if self._trust_remote_code:
tokenizer_kwargs["trust_remote_code"] = self._trust_remote_code

tokenizer = self.load_tokenizer(
tokenizer_class, self._model_name_or_path, **tokenizer_kwargs
tokenizer_class, self._model_processor_name, **tokenizer_kwargs
)

spec = loader(model, tokenizer)
@@ -237,6 +251,19 @@ def set_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
spec.weight = spec.weight.transpose(0, 1)
if module.bias is not None:
spec.bias = module.bias

def set_low_rank_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
if quant_type == common_spec.Quantization.CT2:
spec.low_rank_weight_1 = module.weight1
spec.low_rank_weight_2 = module.weight2
else:
spec.low_rank_weight_1 = module.qweight1
spec.low_rank_weight_2 = module.qweight2
spec.weight_scale = module.scales
spec.weight_zero = module.qzeros

if module.bias is not None:
spec.bias = module.bias

def set_embeddings(self, spec, module):
spec.weight = module.weight
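For context, lite-whisper (LiteASR) stores each compressed projection as two factors rather than one dense matrix, which is exactly what set_low_rank_linear copies into the spec. A minimal NumPy sketch of the idea, with hypothetical shapes and rank (the real factors come from the published checkpoints, and the multiplication order of weight1/weight2 follows the upstream modeling code):

```python
import numpy as np

# Hypothetical sizes: a d x d projection approximated at rank r.
d, r = 1280, 256
W = np.random.randn(d, d).astype(np.float32)

# Truncated SVD yields the two low-rank factors.
U, S, Vt = np.linalg.svd(W, full_matrices=False)
W1 = U[:, :r] * S[:r]      # shape (d, r)
W2 = Vt[:r, :]             # shape (r, d)

x = np.random.randn(d).astype(np.float32)
y_full = W @ x             # one d*d matmul
y_low = W1 @ (W2 @ x)      # two d*r matmuls: fewer parameters and FLOPs
print(np.linalg.norm(y_full - y_low) / np.linalg.norm(y_full))
```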
@@ -996,6 +1023,71 @@ def set_conv1d(self, spec, module):
spec.weight = module.weight
spec.bias = module.bias

@register_loader("LiteWhisperConfig")
class LiteWhisperLoader(WhisperLoader):
@property
def architecture_name(self):
return "LiteWhisperForConditionalGeneration"

def get_model_spec(self, model):
spec = whisper_spec.WhisperSpec(
model.config.encoder_layers,
model.config.encoder_attention_heads,
model.config.decoder_layers,
model.config.decoder_attention_heads,
low_rank=True,
)

self.set_encoder(spec.encoder, model.model.encoder)
self.set_decoder(spec.decoder, model.model.decoder)
self.set_linear(spec.decoder.projection, model.proj_out)

return spec

def set_encoder(self, spec, encoder):
self.set_conv1d(spec.conv1, encoder.conv1)
self.set_conv1d(spec.conv2, encoder.conv2)

self.set_common_layers(spec, encoder)

for layer_spec, layer in zip(spec.layer, encoder.layers):
self.set_low_rank_attention(
layer_spec.self_attention,
layer.self_attn,
)
self.set_layer_norm(
layer_spec.self_attention.layer_norm,
layer.self_attn_layer_norm,
)

# Double-check whether these layers are actually low-rank, since
# individual layers may fall back to full precision.
if hasattr(layer.fc1, 'weight1'):
self.set_low_rank_linear(layer_spec.ffn.linear_0, layer.fc1)
else:
layer_spec.ffn.linear_0 = common_spec.LinearSpec()
self.set_linear(layer_spec.ffn.linear_0, layer.fc1)

if hasattr(layer.fc2, 'weight1'):
self.set_low_rank_linear(layer_spec.ffn.linear_1, layer.fc2)
else:
layer_spec.ffn.linear_1 = common_spec.LinearSpec()
self.set_linear(layer_spec.ffn.linear_1, layer.fc2)

self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)

def set_low_rank_or_linear_router(self, spec, module, i):
if hasattr(module, "weight1"):
self.set_low_rank_linear(spec.linear[i], module)
else:
spec.linear[i] = common_spec.LinearSpec()
self.set_linear(spec.linear[i], module)

def set_low_rank_attention(self, spec, attention):
self.set_low_rank_or_linear_router(spec, attention.q_proj, 0)
self.set_low_rank_or_linear_router(spec, attention.k_proj, 1)
self.set_low_rank_or_linear_router(spec, attention.v_proj, 2)
self.set_low_rank_or_linear_router(spec, attention.out_proj, 3)
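Assuming this loader is registered, converting a lite-whisper checkpoint would presumably follow the usual converter pattern (model id shown as an example; trust_remote_code is needed because the LiteWhisper architecture ships as custom code):

```python
from ctranslate2.converters import TransformersConverter

converter = TransformersConverter(
    "efficient-speech/lite-whisper-large-v3",  # example model id
    trust_remote_code=True,  # LiteWhisperForConditionalGeneration is custom code
)
converter.convert("lite-whisper-large-v3-ct2")
```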

@register_loader("Wav2Vec2Config")
class Wav2Vec2Loader(BartLoader):
10 changes: 7 additions & 3 deletions python/ctranslate2/specs/attention_spec.py
@@ -32,13 +32,17 @@ def __init__(
num_heads_kv=None,
head_dim=None,
sliding_window=None,
low_rank=False,
):
self.queries_scale = model_spec.OPTIONAL

self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
self.linear = [
common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
]
if low_rank:
self.linear = [common_spec.LowRankLinearSpec() for _ in range(4)]
else:
self.linear = [
common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
]

if relative_position:
self.relative_position_keys = None
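So a low-rank self-attention block carries four separate projection specs (q, k, v, and output) instead of the fused two-spec layout. A quick check, as a sketch assuming this branch is installed:

```python
from ctranslate2.specs import attention_spec

att = attention_spec.MultiHeadAttentionSpec(self_attention=True, low_rank=True)
print(len(att.linear))  # 4: q, k, v and output projections, no fused qkv
```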
12 changes: 12 additions & 0 deletions python/ctranslate2/specs/common_spec.py
@@ -51,6 +51,18 @@ def __init__(self):
def has_bias(self):
return not isinstance(self.bias, str)

class LowRankLinearSpec(model_spec.LayerSpec):
def __init__(self):
super().__init__()
self.low_rank_weight_1 = None
self.low_rank_weight_2 = None
self.weight_scale = model_spec.OPTIONAL
self.weight_zero = model_spec.OPTIONAL
self.bias = model_spec.OPTIONAL

def has_bias(self):
return not isinstance(self.bias, str)
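The new spec mirrors LinearSpec but with two weight slots; the converter must fill both for the layer to count as low-rank. A sketch, assuming this branch is installed:

```python
from ctranslate2.specs import common_spec

spec = common_spec.LowRankLinearSpec()
print(spec.low_rank_weight_1, spec.low_rank_weight_2)  # None until the loader fills them
print(spec.has_bias())  # False while bias is still the OPTIONAL string marker
```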


class Conv1DSpec(model_spec.LayerSpec):
def __init__(self):
10 changes: 6 additions & 4 deletions python/ctranslate2/specs/transformer_spec.py
@@ -253,6 +253,7 @@ def __init__(
rms_norm=False,
num_heads_kv=None,
sliding_window=None,
low_rank=False
):
self.self_attention = attention_spec.MultiHeadAttentionSpec(
self_attention=True,
@@ -261,8 +262,9 @@
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
low_rank=low_rank,
)
self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm, low_rank=low_rank)


class TransformerDecoderLayerSpec(model_spec.LayerSpec):
@@ -340,10 +342,10 @@ def __init__(


class FeedForwardSpec(model_spec.LayerSpec):
def __init__(self, glu=False, rms_norm=False):
def __init__(self, glu=False, rms_norm=False, low_rank=False):
self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
self.linear_0 = common_spec.LinearSpec()
self.linear_1 = common_spec.LinearSpec()
self.linear_0 = common_spec.LinearSpec() if not low_rank else common_spec.LowRankLinearSpec()
self.linear_1 = common_spec.LinearSpec() if not low_rank else common_spec.LowRankLinearSpec()
if glu:
self.linear_0_noact = common_spec.LinearSpec()

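With these changes the flag switches the FFN linears to the low-rank spec while the default path is untouched. A sanity-check sketch, again assuming this branch is installed:

```python
from ctranslate2.specs import common_spec, transformer_spec

ffn = transformer_spec.FeedForwardSpec(low_rank=True)
assert isinstance(ffn.linear_0, common_spec.LowRankLinearSpec)
assert isinstance(ffn.linear_1, common_spec.LowRankLinearSpec)

ffn = transformer_spec.FeedForwardSpec()  # default keeps dense linears
assert isinstance(ffn.linear_0, common_spec.LinearSpec)
```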
7 changes: 4 additions & 3 deletions python/ctranslate2/specs/whisper_spec.py
@@ -32,6 +32,7 @@ def __init__(
num_encoder_heads,
num_decoder_layers,
num_decoder_heads,
low_rank=False,
):
"""Initializes the model specification.

@@ -42,7 +43,7 @@
num_decoder_heads: The number of decoder attention heads.
"""
super().__init__()
self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads, low_rank)
self.decoder = transformer_spec.TransformerDecoderSpec(
num_decoder_layers,
num_decoder_heads,
@@ -66,12 +67,12 @@ def get_vocabulary_size(self):


class WhisperEncoderSpec(model_spec.LayerSpec):
def __init__(self, num_layers, num_heads):
def __init__(self, num_layers, num_heads, low_rank=False):
self.num_heads = np.dtype("int16").type(num_heads)
self.conv1 = common_spec.Conv1DSpec()
self.conv2 = common_spec.Conv1DSpec()
self.position_encodings = transformer_spec.PositionEncoderSpec()
self.layer_norm = common_spec.LayerNormSpec()
self.layer = [
transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
transformer_spec.TransformerEncoderLayerSpec(low_rank=low_rank) for _ in range(num_layers)
]
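Note that low_rank propagates only to the encoder layers; the decoder keeps dense linears, which matches lite-whisper compressing the encoder. A sketch (layer and head counts are example values):

```python
from ctranslate2.specs import whisper_spec

spec = whisper_spec.WhisperSpec(
    num_encoder_layers=32,
    num_encoder_heads=20,
    num_decoder_layers=32,
    num_decoder_heads=20,
    low_rank=True,
)
# Encoder layers carry low-rank specs; the decoder spec is unchanged.
```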
23 changes: 20 additions & 3 deletions src/layers/attention.cc
@@ -360,13 +360,22 @@ namespace ctranslate2 {
q = &queries_proj;
}

_linear[0](*q, fused_proj);
if (!_is_low_rank) {
_linear[0](*q, fused_proj);
} else {
// Low-rank attention does not fuse qkv.
_linear[0](*q, queries_proj);
_linear[1](*q, keys_proj);
_linear[2](*q, values_proj);
}

dim_t beam_size = 1;

bool prefilling = (_sliding_window > 0 && values_lengths);

if (!_self_attention) {
if (_is_low_rank)
throw std::invalid_argument("MultiHeadAttention does not support low-rank attention with cross-attention");
queries_proj = std::move(fused_proj);

if (cached_keys == nullptr || cached_keys->empty()) {
@@ -401,6 +410,8 @@
} else {

if (_num_heads_kv < _num_heads) {
if (_is_low_rank)
throw std::invalid_argument("MutliHeadAttention does not support low-rank attention with multi-query or GQA");
if (queries_padder)
queries_padder->add_padding(fused_proj);

@@ -419,8 +430,14 @@
}

} else {
split_heads(fused_proj, 3 * _num_heads, queries_padder);
ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
if (!_is_low_rank) {
split_heads(fused_proj, 3 * _num_heads, queries_padder);
ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
} else {
split_heads(queries_proj, _num_heads, queries_padder);
split_heads(keys_proj, _num_heads_kv, queries_padder);
split_heads(values_proj, _num_heads_kv, queries_padder);
}
}

if (_rotary_embeddings) {
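The branch above mirrors this control flow: the dense path projects once into a fused [q; k; v] tensor and splits it, while the low-rank path must run three separate two-factor projections because each of q/k/v has its own factor pair. A NumPy sketch with hypothetical shapes:

```python
import numpy as np

d_model, seq = 512, 10
x = np.random.randn(seq, d_model).astype(np.float32)

# Fused path: one (3*d_model, d_model) matmul, then split.
w_fused = np.random.randn(3 * d_model, d_model).astype(np.float32)
q, k, v = np.split(x @ w_fused.T, 3, axis=-1)

# Low-rank path: separate factor pairs per projection, so no fusion.
def low_rank_proj(x, rank):
    w1 = np.random.randn(rank, x.shape[-1]).astype(np.float32)
    w2 = np.random.randn(d_model, rank).astype(np.float32)
    return (x @ w1.T) @ w2.T

q = low_rank_proj(x, rank=64)  # each projection may even use a different rank
k = low_rank_proj(x, rank=48)
v = low_rank_proj(x, rank=48)
```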
24 changes: 20 additions & 4 deletions src/layers/attention_layer.cc
@@ -51,10 +51,25 @@ namespace ctranslate2 {
return alibi;
}

static bool set_low_rank(const models::Model& model, const std::string& scope) {
const StorageView* low_rank_weight = model.get_variable_if_exists(scope + "/linear_0/low_rank_weight_1");
if (low_rank_weight) {
return true;
}
return false;
}

static std::vector<Dense> make_linear_layers(const models::Model& model,
const std::string& scope,
bool self_attention) {
const dim_t num_linear_layers = self_attention ? 2 : 3;
bool self_attention,
bool _is_low_rank) {
dim_t num_linear_layers;
if (!_is_low_rank) {
num_linear_layers = self_attention ? 2 : 3;
} else {
num_linear_layers = 4;
}

std::vector<Dense> layers;
layers.reserve(num_linear_layers);
for (dim_t i = 0; i < num_linear_layers; ++i)
@@ -117,11 +132,12 @@ namespace ctranslate2 {
bool is_decoder,
Alibi* alibi,
bool is_flash_attn)
: _tensor_parallel(model.tensor_parallel())
: _is_low_rank(set_low_rank(model, scope))
, _tensor_parallel(model.tensor_parallel())
, _num_heads(_tensor_parallel ? SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()) : num_heads)
, _self_attention(self_attention)
, _is_decoder(is_decoder)
, _linear(make_linear_layers(model, scope, self_attention))
, _linear(make_linear_layers(model, scope, self_attention, _is_low_rank))
, _d_model(_tensor_parallel ? SAFE_DIVIDE(_linear.back().output_size(), ScopedMPISetter::getNRanks()) : _linear.back().output_size())
, _d_head(model.get_attribute_with_default<int32_t >(scope + "/head_dim", _d_model / _num_heads))
, _pre_norm(pre_norm)
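The detection in set_low_rank hinges on whether the converted model contains a low_rank_weight_1 variable under the attention scope. In Python terms the check is roughly the following (hypothetical scope names, with a dict standing in for the model's variable store):

```python
variables = {
    "encoder/layer_0/self_attention/linear_0/low_rank_weight_1": "...",
    # ... other converted weights ...
}

def is_low_rank(scope: str) -> bool:
    # Mirrors set_low_rank(): probe for the first low-rank factor.
    return f"{scope}/linear_0/low_rank_weight_1" in variables

print(is_low_rank("encoder/layer_0/self_attention"))  # True
print(is_low_rank("decoder/layer_0/self_attention"))  # False
```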