Skip to content

Commit 1cfabe0

Browse files
authored
Merge branch 'OpenNMT:master' into master
2 parents a55386b + 2870fe3 commit 1cfabe0

File tree

7 files changed

+224
-93
lines changed

7 files changed

+224
-93
lines changed

.github/workflows/ci.yml

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
backend: [mkl, dnnl]
2222

2323
steps:
24-
- uses: actions/checkout@v3
24+
- uses: actions/checkout@v4
2525
with:
2626
submodules: recursive
2727

@@ -82,7 +82,7 @@ jobs:
8282
backend: [openblas]
8383

8484
steps:
85-
- uses: actions/checkout@v3
85+
- uses: actions/checkout@v4
8686
with:
8787
submodules: recursive
8888

@@ -137,11 +137,11 @@ jobs:
137137
include:
138138
- os: ubuntu-20.04
139139
arch: aarch64
140-
- os: macos-12
140+
- os: macos-13
141141
arch: arm64
142142

143143
steps:
144-
- uses: actions/checkout@v3
144+
- uses: actions/checkout@v4
145145
with:
146146
submodules: recursive
147147

@@ -150,7 +150,7 @@ jobs:
150150
name: Set up QEMU
151151

152152
- name: Build wheels
153-
uses: pypa/cibuildwheel@v2.16.5
153+
uses: pypa/cibuildwheel@v2.21.3
154154
with:
155155
package-dir: python
156156
output-dir: python/wheelhouse
@@ -168,9 +168,9 @@ jobs:
168168
CIBW_SKIP: pp* *-musllinux_*
169169

170170
- name: Upload Python wheels
171-
uses: actions/upload-artifact@v3
171+
uses: actions/upload-artifact@v4
172172
with:
173-
name: python-wheels
173+
name: python-wheels-${{ runner.os }}-${{ matrix.arch }}
174174
path: python/wheelhouse
175175

176176

@@ -185,21 +185,23 @@ jobs:
185185

186186
steps:
187187
- name: Set up Python 3.8
188-
uses: actions/setup-python@v4
188+
uses: actions/setup-python@v5
189189
with:
190190
python-version: 3.8
191191

192-
- uses: actions/checkout@v3
192+
- uses: actions/checkout@v4
193193

194194
- name: Prepare test environment
195195
shell: bash
196196
run: |
197197
./python/tools/prepare_test_environment.sh
198198
199199
- name: Download Python wheels
200-
uses: actions/download-artifact@v3
200+
uses: actions/download-artifact@v4
201201
with:
202-
name: python-wheels
202+
pattern: python-wheels-${{ runner.os }}-*
203+
merge-multiple: true
204+
path: .
203205

204206
- name: Install wheel
205207
if: startsWith(matrix.os, 'ubuntu')
@@ -222,10 +224,10 @@ jobs:
222224
runs-on: ubuntu-latest
223225

224226
steps:
225-
- uses: actions/checkout@v3
227+
- uses: actions/checkout@v4
226228

227229
- name: Set up Python 3.8
228-
uses: actions/setup-python@v4
230+
uses: actions/setup-python@v5
229231
with:
230232
python-version: 3.8
231233

@@ -257,9 +259,11 @@ jobs:
257259

258260
steps:
259261
- name: Download Python wheels
260-
uses: actions/download-artifact@v3
262+
uses: actions/download-artifact@v4
261263
with:
262-
name: python-wheels
264+
pattern: python-wheels-*
265+
merge-multiple: true
266+
path: .
263267

264268
- name: Publish Python wheels to PyPI
265269
uses: pypa/gh-action-pypi-publish@release/v1
@@ -272,7 +276,7 @@ jobs:
272276
build-and-push-docker-images:
273277
runs-on: ubuntu-20.04
274278
steps:
275-
- uses: actions/checkout@v3
279+
- uses: actions/checkout@v4
276280
with:
277281
submodules: recursive
278282

@@ -299,17 +303,19 @@ jobs:
299303
needs: [check-python-style, build-python-wheels]
300304

301305
steps:
302-
- uses: actions/checkout@v3
306+
- uses: actions/checkout@v4
303307

304308
- name: Set up Python 3.8
305-
uses: actions/setup-python@v4
309+
uses: actions/setup-python@v5
306310
with:
307311
python-version: 3.8
308312

309313
- name: Download CTranslate2 wheels
310-
uses: actions/download-artifact@v3
314+
uses: actions/download-artifact@v4
311315
with:
312-
name: python-wheels
316+
pattern: python-wheels-${{ runner.os }}-*
317+
merge-multiple: true
318+
path: .
313319

314320
- name: Install CTranslate2 wheel
315321
run: |

CMakeLists.txt

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -558,39 +558,9 @@ if (WITH_CUDA)
558558
else()
559559
list(APPEND LIBRARIES ${CUDA_CUBLAS_LIBRARIES})
560560
endif()
561-
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
562-
cuda_add_library(${PROJECT_NAME}
563-
${SOURCES}
564-
src/cuda/allocator.cc
565-
src/cuda/primitives.cu
566-
src/cuda/random.cu
567-
src/cuda/utils.cc
568-
src/ops/alibi_add_gpu.cu
569-
src/ops/bias_add_gpu.cu
570-
src/ops/concat_split_slide_gpu.cu
571-
src/ops/conv1d_gpu.cu
572-
src/ops/dequantize_gpu.cu
573-
src/ops/flash_attention_gpu.cu
574-
src/ops/gather_gpu.cu
575-
src/ops/gumbel_max_gpu.cu
576-
src/ops/layer_norm_gpu.cu
577-
src/ops/mean_gpu.cu
578-
src/ops/multinomial_gpu.cu
579-
src/ops/rms_norm_gpu.cu
580-
src/ops/rotary_gpu.cu
581-
src/ops/softmax_gpu.cu
582-
src/ops/tile_gpu.cu
583-
src/ops/topk_gpu.cu
584-
src/ops/topp_mask_gpu.cu
585-
src/ops/quantize_gpu.cu
586-
src/ops/nccl_ops_gpu.cu
587-
src/ops/awq/gemm_gpu.cu
588-
src/ops/awq/gemv_gpu.cu
589-
src/ops/awq/dequantize_gpu.cu
590-
)
591561
if (WITH_FLASH_ATTN)
592562
add_definitions(-DCT2_WITH_FLASH_ATTN)
593-
cuda_add_library(${PROJECT_NAME}
563+
list(APPEND SOURCES
594564
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
595565
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
596566
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
@@ -660,6 +630,36 @@ if (WITH_CUDA)
660630
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
661631
PROPERTIES COMPILE_FLAGS "--use_fast_math")
662632
endif()
633+
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
634+
cuda_add_library(${PROJECT_NAME}
635+
${SOURCES}
636+
src/cuda/allocator.cc
637+
src/cuda/primitives.cu
638+
src/cuda/random.cu
639+
src/cuda/utils.cc
640+
src/ops/alibi_add_gpu.cu
641+
src/ops/bias_add_gpu.cu
642+
src/ops/concat_split_slide_gpu.cu
643+
src/ops/conv1d_gpu.cu
644+
src/ops/dequantize_gpu.cu
645+
src/ops/flash_attention_gpu.cu
646+
src/ops/gather_gpu.cu
647+
src/ops/gumbel_max_gpu.cu
648+
src/ops/layer_norm_gpu.cu
649+
src/ops/mean_gpu.cu
650+
src/ops/multinomial_gpu.cu
651+
src/ops/rms_norm_gpu.cu
652+
src/ops/rotary_gpu.cu
653+
src/ops/softmax_gpu.cu
654+
src/ops/tile_gpu.cu
655+
src/ops/topk_gpu.cu
656+
src/ops/topp_mask_gpu.cu
657+
src/ops/quantize_gpu.cu
658+
src/ops/nccl_ops_gpu.cu
659+
src/ops/awq/gemm_gpu.cu
660+
src/ops/awq/gemv_gpu.cu
661+
src/ops/awq/dequantize_gpu.cu
662+
)
663663

664664

665665
elseif(WITH_CUDNN)

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ The project implements a custom runtime that applies many performance optimizati
99
The following model types are currently supported:
1010

1111
* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
12-
* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon
12+
* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
1313
* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa
1414

1515
Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks:

include/ctranslate2/layers/wav2vec2.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <optional>
34
#include "ctranslate2/layers/transformer.h"
45

56
namespace ctranslate2 {
@@ -81,17 +82,18 @@ namespace ctranslate2 {
8182
}
8283

8384
private:
84-
const Wav2Vec2LayerNormConvLayer _feat_layer0;
85-
const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
86-
const LayerNorm _fp_norm;
87-
const Dense _fp_ff;
88-
const Wav2Vec2PosConvLayer _pos_conv_embed;
85+
const StorageView* _upgraded_model;
86+
std::optional<Wav2Vec2LayerNormConvLayer> _feat_layer0;
87+
std::optional<std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>>> _feat_layers;
88+
std::optional<LayerNorm> _fp_norm;
89+
std::optional<Dense> _fp_ff;
90+
std::optional<Wav2Vec2PosConvLayer> _pos_conv_embed;
8991
const ops::Transpose _transpose;
9092
const ops::GELU _gelu;
9193
const dim_t _num_heads;
9294
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
9395
const LayerNorm _output_norm;
94-
const Dense _lm_head;
96+
std::optional<Dense> _lm_head;
9597
};
9698

9799
}

python/cpp/wav2vec2.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ namespace ctranslate2 {
8686
Encodes the input features.
8787
8888
Arguments:
89-
features: Mel spectogram of the audio, as a float array with shape
90-
``[batch_size, 80, 3000]``.
89+
features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or
90+
raw audio, as a float array with shape (followed by VAD)
91+
``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``
9192
to_cpu: Copy the encoder output to the CPU before returning the value.
9293
9394
Returns:

python/ctranslate2/converters/transformers.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,6 +1956,114 @@ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
19561956
gc.collect()
19571957

19581958

1959+
@register_loader("Qwen2Config")
1960+
class Qwen2Loader(ModelLoader):
1961+
@property
1962+
def architecture_name(self):
1963+
return "Qwen2ForCausalLM"
1964+
1965+
def get_model_spec(self, model):
1966+
num_layers = model.config.num_hidden_layers
1967+
1968+
num_heads = model.config.num_attention_heads
1969+
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
1970+
if num_heads_kv == num_heads:
1971+
num_heads_kv = None
1972+
1973+
rope_scaling = getattr(model.config, "rope_scaling", None)
1974+
if rope_scaling:
1975+
rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
1976+
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
1977+
rotary_scaling_factor = rope_scaling["factor"]
1978+
1979+
if rotary_scaling_type is None:
1980+
raise NotImplementedError(
1981+
"RoPE scaling type '%s' is not yet implemented. "
1982+
"The following RoPE scaling types are currently supported: %s"
1983+
% (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
1984+
)
1985+
else:
1986+
rotary_scaling_type = None
1987+
rotary_scaling_factor = 1
1988+
1989+
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
1990+
num_layers,
1991+
num_heads,
1992+
activation=common_spec.Activation.SWISH,
1993+
pre_norm=True,
1994+
ffn_glu=True,
1995+
rms_norm=True,
1996+
rotary_dim=0,
1997+
rotary_interleave=False,
1998+
rotary_scaling_type=rotary_scaling_type,
1999+
rotary_scaling_factor=rotary_scaling_factor,
2000+
rotary_base=getattr(model.config, "rope_theta", 10000),
2001+
num_heads_kv=num_heads_kv,
2002+
)
2003+
2004+
self.set_decoder(spec.decoder, model.model)
2005+
self.set_linear(spec.decoder.projection, model.lm_head)
2006+
return spec
2007+
2008+
def get_vocabulary(self, model, tokenizer):
2009+
tokens = super().get_vocabulary(model, tokenizer)
2010+
2011+
extra_ids = model.config.vocab_size - len(tokens)
2012+
for i in range(extra_ids):
2013+
tokens.append("<extra_id_%d>" % i)
2014+
return tokens
2015+
2016+
def set_vocabulary(self, spec, tokens):
2017+
spec.register_vocabulary(tokens)
2018+
2019+
def set_config(self, config, model, tokenizer):
2020+
config.bos_token = (
2021+
tokenizer.bos_token
2022+
if tokenizer.bos_token is not None
2023+
else tokenizer.pad_token
2024+
)
2025+
config.eos_token = tokenizer.eos_token
2026+
config.unk_token = (
2027+
tokenizer.unk_token if tokenizer.unk_token is not None else ""
2028+
)
2029+
config.layer_norm_epsilon = model.config.rms_norm_eps
2030+
2031+
def set_layer_norm(self, spec, layer_norm):
2032+
spec.gamma = layer_norm.weight
2033+
2034+
def set_decoder(self, spec, module):
2035+
spec.scale_embeddings = False
2036+
self.set_embeddings(spec.embeddings, module.embed_tokens)
2037+
self.set_layer_norm(spec.layer_norm, module.norm)
2038+
2039+
for layer_spec, layer in zip(spec.layer, module.layers):
2040+
self.set_layer_norm(
2041+
layer_spec.self_attention.layer_norm, layer.input_layernorm
2042+
)
2043+
self.set_layer_norm(
2044+
layer_spec.ffn.layer_norm, layer.post_attention_layernorm
2045+
)
2046+
2047+
split_layers = [common_spec.LinearSpec() for _ in range(3)]
2048+
self.set_linear(split_layers[0], layer.self_attn.q_proj)
2049+
self.set_linear(split_layers[1], layer.self_attn.k_proj)
2050+
self.set_linear(split_layers[2], layer.self_attn.v_proj)
2051+
2052+
utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
2053+
self.set_linear(
2054+
layer_spec.self_attention.linear[1],
2055+
layer.self_attn.o_proj,
2056+
)
2057+
2058+
self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
2059+
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
2060+
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
2061+
2062+
delattr(layer, "self_attn")
2063+
delattr(layer, "mlp")
2064+
gc.collect()
2065+
2066+
19592067
@register_loader("MixFormerSequentialConfig")
19602068
class MixFormerSequentialLoader(ModelLoader):
19612069
@property

0 commit comments

Comments
 (0)