
Commit b652c23

first commit
0 parents  commit b652c23

File tree

9 files changed: +552 additions, -0 deletions

.gitignore

Lines changed: 105 additions & 0 deletions

# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

LICENSE

Lines changed: 21 additions & 0 deletions

The MIT License (MIT)

Copyright (C) 2017 Ines Montani

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

README.md

Lines changed: 107 additions & 0 deletions

# spacybert: Bert inference for spaCy

[spaCy v2.0](https://spacy.io/usage/v2) extension and pipeline component for loading BERT sentence / document embedding metadata onto `Doc`, `Span` and `Token` objects. The BERT backend itself is provided by the [Hugging Face transformers](https://github.com/huggingface/transformers) library.

## Installation

`spacybert` requires `spacy` v2.0.0 or higher.
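
Since this commit ships a `setup.py`, a local checkout can be installed with `pip`; the PyPI package name below is an assumption and only applies if the package is published there:

```
# install from a local checkout of this repository
pip install .

# or, assuming the package is published on PyPI under the same name
pip install spacybert
```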

## Usage

### Getting BERT embeddings for a single-language dataset

```
import spacy
from spacybert import BertInference
nlp = spacy.load('en')
```

Then either use `BertInference` as part of a pipeline,

```
bert = BertInference(
    from_pretrained='path/to/pretrained_bert_weights_dir',
    set_extension=False)
nlp.add_pipe(bert, last=True)
```

or use it on its own:

```
bert = BertInference(
    from_pretrained='path/to/pretrained_bert_weights_dir',
    set_extension=True)
```

The difference is that with `set_extension=True`, `bert_repr` is set as a property extension on the `Doc`, `Span` and `Token` spaCy objects. With `set_extension=False`, `bert_repr` is set as an attribute extension with a default value (`None`), and the correct value is computed when `doc._.bert_repr` is accessed in a pipeline.

Get the BERT representation / embedding:

```
doc = nlp("This is a test")
print(doc._.bert_repr)  # <-- torch.Tensor
```

### Getting BERT embeddings for a multi-language dataset

```
import spacy
from spacy_langdetect import LanguageDetector
from spacybert import MultiLangBertInference

nlp = spacy.load('en')
nlp.add_pipe(LanguageDetector(), name='language_detector', last=True)
bert = MultiLangBertInference(
    from_pretrained={
        'en': 'path/to/en_pretrained_bert_weights_dir',
        'nl': 'path/to/nl_pretrained_bert_weights_dir'
    },
    set_extension=False)
nlp.add_pipe(bert, after='language_detector')

texts = [
    "This is a test",  # English
    "Dit is een test"  # Dutch
]
for doc in nlp.pipe(texts):
    print(doc._.bert_repr)  # <-- torch.Tensor
```

When the language detector detects a language for which no pre-trained weights are specified, `doc._.bert_repr` defaults to `None`.

## Available attributes

The extension sets attributes on the `Doc`, `Span` and `Token` objects. You can change the attribute name when initializing the extension. A short example follows the table.

| attribute | type | description |
|-|-|-|
| `Doc._.bert_repr` | `torch.Tensor` | Document BERT embedding |
| `Span._.bert_repr` | `torch.Tensor` | Span BERT embedding |
| `Token._.bert_repr` | `torch.Tensor` | Token BERT embedding |
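
A minimal sketch of reading these attributes, assuming the single-language pipeline from the usage section above (with `set_extension=False`):

```
doc = nlp("This is a test")
print(doc._.bert_repr)       # Doc-level embedding (torch.Tensor)
print(doc[0:2]._.bert_repr)  # Span-level embedding for the first two tokens
print(doc[0]._.bert_repr)    # Token-level embedding for the first token
```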

## Settings

On initialization of `BertInference`, you can define the following (an example follows the table):

| name | type | default | description |
|-|-|-|-|
| `from_pretrained` | `str` | `None` | Path to a BERT model directory, or the name of HuggingFace transformers pre-trained BERT weights, e.g. `bert-base-uncased` |
| `attr_name` | `str` | `'bert_repr'` | Name of the BERT embedding attribute set on the `._` property |
| `max_seq_len` | `int` | `512` | Maximum sequence length for input to BERT |
| `pooling_strategy` | `str` | `'REDUCE_MEAN'` | Strategy for generating a single sentence embedding from multiple word embeddings. See the available pooling strategies below. |
| `set_extension` | `bool` | `True` | If `True`, `'bert_repr'` is set as a property extension on the `Doc`, `Span` and `Token` spaCy objects. If `False`, `'bert_repr'` is set as an attribute extension with a default value (`None`) which gets filled correctly when the component runs in a pipeline. Set it to `False` if you want to use this extension in a spaCy pipeline. |
| `force_extension` | `bool` | `True` | If `True`, re-create the extension attribute even if one with the same name already exists, so the component can be instantiated more than once |
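
For illustration, a sketch of a fully spelled-out initialization; the keyword names come from the table above and the values shown are simply the defaults (with `set_extension=False` for pipeline use):

```
bert = BertInference(
    from_pretrained='bert-base-uncased',
    attr_name='bert_repr',
    max_seq_len=512,
    pooling_strategy='REDUCE_MEAN',
    set_extension=False,
    force_extension=True)
nlp.add_pipe(bert, last=True)
```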

On initialization of `MultiLangBertInference`, you can define the following:

| name | type | default | description |
|-|-|-|-|
| `from_pretrained` | `Dict[LANG_ISO_639_1, str]` | `None` | Mapping from two-letter (ISO 639-1) language codes to paths of BERT model directories or names of HuggingFace transformers pre-trained BERT weights |
| `attr_name` | `str` | `'bert_repr'` | Same as in `BertInference` |
| `max_seq_len` | `int` | `512` | Same as in `BertInference` |
| `pooling_strategy` | `str` | `'REDUCE_MEAN'` | Same as in `BertInference` |
| `set_extension` | `bool` | `True` | Same as in `BertInference` |
| `force_extension` | `bool` | `True` | Same as in `BertInference` |

## Pooling strategies

An illustrative sketch follows the table.

| strategy | description |
|-|-|
| `REDUCE_MEAN` | Element-wise average of the word embeddings |
| `REDUCE_MAX` | Element-wise maximum of the word embeddings |
| `REDUCE_MEAN_MAX` | Apply both `'REDUCE_MEAN'` and `'REDUCE_MAX'` and concatenate the results, so if the original word embeddings have dimension `(768,)`, the output has shape `(1536,)` |
| `CLS_TOKEN`, `FIRST_TOKEN` | Take only the embedding of the first `[CLS]` token |
| `SEP_TOKEN`, `LAST_TOKEN` | Take only the embedding of the last `[SEP]` token |
| `None` | No reduction is applied; a matrix with one embedding per word in the sentence is returned |
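
As an illustration only (not the package's internal implementation), the reductions can be pictured on a `(seq_len, hidden_dim)` matrix of word embeddings:

```
import torch

word_embs = torch.randn(6, 768)  # (seq_len, hidden_dim) word embeddings, illustrative values

reduce_mean = word_embs.mean(dim=0)                     # REDUCE_MEAN     -> shape (768,)
reduce_max = word_embs.max(dim=0).values                # REDUCE_MAX      -> shape (768,)
reduce_mean_max = torch.cat([reduce_mean, reduce_max])  # REDUCE_MEAN_MAX -> shape (1536,)
cls_token = word_embs[0]                                # CLS_TOKEN / FIRST_TOKEN
sep_token = word_embs[-1]                               # SEP_TOKEN / LAST_TOKEN
```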

## Roadmap

This extension is still experimental. Possible future updates include:
* Getting document representations from state-of-the-art NLP models other than Google's BERT.
* A method for computing similarity between `Doc`, `Span` and `Token` objects using the `bert_repr` tensor.
* Getting representations from multiple / other layers of the models.

requirements.txt

Lines changed: 4 additions & 0 deletions

torch>=1.4.0
transformers>=3.0.0
spacy>=2.2.1,<3.0.0
spacy-langdetect>=0.1.2

setup.py

Lines changed: 39 additions & 0 deletions

from pathlib import Path
from setuptools import setup, find_packages

package_name = 'spacybert'
root = Path(__file__).parent.resolve()

# Read in package meta from about.py
about_path = root / package_name / 'about.py'
with about_path.open('r', encoding='utf8') as f:
    about = {}
    exec(f.read(), about)

# Get readme
readme_path = root / 'README.md'
with readme_path.open('r', encoding='utf8') as f:
    readme = f.read()

install_requires = [
    'torch>=1.4.0',
    'transformers>=3.0.0',
    'spacy>=2.2.1,<3.0.0',
    'spacy-langdetect>=0.1.2'
]
test_requires = ['pytest']

setup(
    name=package_name,
    description=about['__summary__'],
    long_description=readme,
    author=about['__author__'],
    author_email=about['__email__'],
    url=about['__uri__'],
    version=about['__version__'],
    license=about['__license__'],
    packages=find_packages(),
    install_requires=install_requires,
    tests_require=test_requires,  # setuptools expects the keyword 'tests_require'
    zip_safe=False,
)
