Commit b4132be

fix mobilebert register (#1890)
1 parent 58cea12 commit b4132be

3 files changed: +180 -1 lines changed

mindnlp/transformers/models/auto/configuration_auto.py

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@
         ("mistral", "MistralConfig"),
         ("mixtral", "MixtralConfig"),
         ("mllama", "MllamaConfig"),
+        ("mobilebert", "MobileBertConfig"),
         ("mobilevit", "MobileViTConfig"),
         ("mobilenet_v1", "MobileNetV1Config"),
         ("mobilenet_v2", "MobileNetV2Config"),

mindnlp/transformers/models/mobilebert/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -17,13 +17,15 @@
 """
 MobileBert Models init
 """
-from . import configuration_mobilebert, modeling_mobilebert, tokenization_mobilebert
+from . import configuration_mobilebert, modeling_mobilebert, tokenization_mobilebert, tokenization_mobilebert_fast

 from .modeling_mobilebert import *
 from .configuration_mobilebert import *
 from .tokenization_mobilebert import *
+from .tokenization_mobilebert_fast import *

 __all__ = []
 __all__.extend(modeling_mobilebert.__all__)
 __all__.extend(configuration_mobilebert.__all__)
 __all__.extend(tokenization_mobilebert.__all__)
+__all__.extend(tokenization_mobilebert_fast.__all__)
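
With the module imported and its names appended to `__all__`, `MobileBertTokenizerFast` is re-exported from the package. A minimal sketch of the import this enables, assuming a hosted MobileBERT checkpoint is available (the id below is illustrative):

# Sketch only: the re-export added above makes this import resolvable.
from mindnlp.transformers.models.mobilebert import MobileBertTokenizerFast

tokenizer = MobileBertTokenizerFast.from_pretrained("google/mobilebert-uncased")  # illustrative checkpoint id
print(tokenizer("hello world")["input_ids"])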

mindnlp/transformers/models/mobilebert/tokenization_mobilebert_fast.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
# coding=utf-8
#
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for MobileBERT."""

import json
from typing import List, Optional, Tuple

from tokenizers import normalizers

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from .tokenization_mobilebert import MobileBertTokenizer
from ....utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}


# Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast with BERT->MobileBERT,Bert->MobileBert
class MobileBertTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" MobileBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        clean_text (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the text before tokenization by removing any control characters and replacing all
            whitespaces by the classic one.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
            issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original MobileBERT).
        wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
            The prefix for subwords.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = MobileBertTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        if (
            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
        ):
            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
            normalizer_state["lowercase"] = do_lower_case
            normalizer_state["strip_accents"] = strip_accents
            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

        self.do_lower_case = do_lower_case

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A MobileBERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        if token_ids_1 is not None:
            output += token_ids_1 + [self.sep_token_id]

        return output

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A MobileBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

__all__ = ['MobileBertTokenizerFast']
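
For reference, the layouts produced by `build_inputs_with_special_tokens` and `create_token_type_ids_from_sequences` above can be traced with a standalone sketch; the token IDs are toy values, not drawn from a real vocabulary:

# Standalone sketch of the [CLS] A [SEP] B [SEP] layout and its token type IDs.
CLS, SEP = 101, 102          # toy special-token IDs
token_ids_0 = [7, 8, 9]      # first sequence
token_ids_1 = [20, 21]       # second sequence

inputs = [CLS] + token_ids_0 + [SEP] + token_ids_1 + [SEP]
type_ids = [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)

print(inputs)    # [101, 7, 8, 9, 102, 20, 21, 102]
print(type_ids)  # [0, 0, 0, 0, 0, 1, 1, 1]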
