Skip to content

Commit 1b477a9

Browse files
authored
Merge pull request #108 from FengZiYjun/trainer
FastNLP v0.2
2 parents 15262bd + db0a789 commit 1b477a9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+5432
-1731
lines changed

README.md

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,39 @@
66
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
77
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
88

9-
fastNLP is a modular Natural Language Processing system based on PyTorch, for fast development of NLP tools. It divides the NLP model based on deep learning into different modules. These modules fall into 4 categories: encoder, interaction, aggregation and decoder, while each category contains different implemented modules. Encoder modules encode the input into some abstract representation, interaction modules make the information in the representation interact with each other, aggregation modules aggregate and reduce information, and decoder modules decode the representation into the output. Most current NLP models could be built on these modules, which vastly simplifies the process of developing NLP models. The architecture of fastNLP is as the figure below:
9+
FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.
1010

11-
![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/procedures.PNG)
12-
![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/text_classification.png)
11+
A deep learning NLP model is the composition of three types of modules:
12+
<table>
13+
<tr>
14+
<td><b> module type </b></td>
15+
<td><b> functionality </b></td>
16+
<td><b> example </b></td>
17+
</tr>
18+
<tr>
19+
<td> encoder </td>
20+
<td> encode the input into some abstract representation </td>
21+
<td> embedding, RNN, CNN, transformer </td>
22+
</tr>
23+
<tr>
24+
<td> aggregator </td>
25+
<td> aggregate and reduce information </td>
26+
<td> self-attention, max-pooling </td>
27+
</tr>
28+
<tr>
29+
<td> decoder </td>
30+
<td> decode the representation into the output </td>
31+
<td> MLP, CRF </td>
32+
</tr>
</table>
33+
34+
For example:
35+
36+
![](docs/source/figures/text_classification.png)
1337

1438
## Requirements
1539

1640
- numpy>=1.14.2
1741
- torch>=0.4.0
18-
- torchvision>=0.1.8
1942
- tensorboardX
2043

2144

@@ -39,12 +62,12 @@ pip install fastNLP
3962
<td> an open-source NLP library </td>
4063
</tr>
4164
<tr>
42-
<td><b> fastNLP.core </b></td>
43-
<td> trainer, tester, predictor </td>
65+
<td><b> fastNLP.api </b></td>
66+
<td> APIs for end-to-end prediction </td>
4467
</tr>
4568
<tr>
46-
<td><b> fastNLP.loader </b></td>
47-
<td> all kinds of loaders/readers </td>
69+
<td><b> fastNLP.core </b></td>
70+
<td> data representation & train/test procedure </td>
4871
</tr>
4972
<tr>
5073
<td><b> fastNLP.models </b></td>
@@ -55,11 +78,7 @@ pip install fastNLP
5578
<td> a collection of PyTorch sub-models/components/wheels </td>
5679
</tr>
5780
<tr>
58-
<td><b> fastNLP.saver </b></td>
59-
<td> all kinds of savers/writers </td>
60-
</tr>
61-
<tr>
62-
<td><b> fastNLP.fastnlp </b></td>
63-
<td> a high-level interface for prediction </td>
81+
<td><b> fastNLP.io </b></td>
82+
<td> readers & savers </td>
6483
</tr>
6584
</table>

docs/quick_tutorial.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
# FastNLP Quick Tutorial
1+
# FastNLP Quick Tutorial
2+
18.9 KB
Loading

fastNLP/api/model_zoo.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import torch
2-
31
import hashlib
42
import os
53
import re
64
import shutil
75
import sys
86
import tempfile
97

8+
import torch
9+
1010
try:
1111
from requests.utils import urlparse
1212
from requests import get as urlopen
@@ -132,7 +132,3 @@ def __exit__(self, exc_type, exc_val, exc_tb):
132132

133133
sys.stderr.write('\n')
134134

135-
136-
if __name__ == '__main__':
137-
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.')
138-
print(type(pipeline))

fastNLP/api/processor.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import torch
2-
from collections import defaultdict
31
import re
2+
from collections import defaultdict
3+
4+
import torch
45

5-
from fastNLP.core.dataset import DataSet
6-
from fastNLP.core.vocabulary import Vocabulary
76
from fastNLP.core.batch import Batch
7+
from fastNLP.core.dataset import DataSet
88
from fastNLP.core.sampler import SequentialSampler
9+
from fastNLP.core.vocabulary import Vocabulary
910

1011

11-
class Processor:
12+
class Processor(object):
1213
def __init__(self, field_name, new_added_field_name):
1314
self.field_name = field_name
1415
if new_added_field_name is None:
@@ -17,7 +18,7 @@ def __init__(self, field_name, new_added_field_name):
1718
self.new_added_field_name = new_added_field_name
1819

1920
def process(self, *args, **kwargs):
20-
pass
21+
raise NotImplementedError
2122

2223
def __call__(self, *args, **kwargs):
2324
return self.process(*args, **kwargs)
@@ -132,27 +133,29 @@ def process(self, dataset):
132133

133134

134135
class IndexerProcessor(Processor):
135-
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
136+
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):
136137

137138
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
138139

139140
super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
140141
self.vocab = vocab
141142
self.delete_old_field = delete_old_field
143+
self.is_input = is_input
142144

143145
def set_vocab(self, vocab):
144146
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
145147

146148
self.vocab = vocab
147149

148150
def process(self, dataset):
149-
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
151+
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
150152
for ins in dataset:
151153
tokens = ins[self.field_name]
152154
index = [self.vocab.to_index(token) for token in tokens]
153155
ins[self.new_added_field_name] = index
154156

155-
dataset._set_need_tensor(**{self.new_added_field_name: True})
157+
if self.is_input:
158+
dataset.set_input(self.new_added_field_name)
156159

157160
if self.delete_old_field:
158161
dataset.delete_field(self.field_name)
@@ -161,6 +164,9 @@ def process(self, dataset):
161164

162165

163166
class VocabProcessor(Processor):
167+
"""Build vocabulary with a field in the data set.
168+
169+
"""
164170
def __init__(self, field_name):
165171
super(VocabProcessor, self).__init__(field_name, None)
166172
self.vocab = Vocabulary()
@@ -178,17 +184,20 @@ def get_vocab(self):
178184

179185

180186
class SeqLenProcessor(Processor):
181-
def __init__(self, field_name, new_added_field_name='seq_lens'):
187+
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
182188
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
189+
self.is_input = is_input
183190

184191
def process(self, dataset):
185192
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
186193
for ins in dataset:
187194
length = len(ins[self.field_name])
188195
ins[self.new_added_field_name] = length
189-
dataset._set_need_tensor(**{self.new_added_field_name: True})
196+
if self.is_input:
197+
dataset.set_input(self.new_added_field_name)
190198
return dataset
191199

200+
192201
class ModelProcessor(Processor):
193202
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
194203
"""
@@ -238,6 +247,7 @@ def set_model_device(self, device):
238247
device = torch.device(device)
239248
self.model.to(device)
240249

250+
241251
class Index2WordProcessor(Processor):
242252
def __init__(self, vocab, field_name, new_added_field_name):
243253
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
@@ -251,26 +261,28 @@ def process(self, dataset):
251261

252262

253263
class SetTensorProcessor(Processor):
264+
# TODO: remove it. It is strange.
254265
def __init__(self, field_dict, default=False):
255266
super(SetTensorProcessor, self).__init__(None, None)
256267
self.field_dict = field_dict
257268
self.default = default
258269

259270
def process(self, dataset):
260-
set_dict = {name: self.default for name in dataset.get_fields().keys()}
271+
set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
261272
set_dict.update(self.field_dict)
262273
dataset._set_need_tensor(**set_dict)
263274
return dataset
264275

265276

266277
class SetIsTargetProcessor(Processor):
278+
# TODO: remove it.
267279
def __init__(self, field_dict, default=False):
268280
super(SetIsTargetProcessor, self).__init__(None, None)
269281
self.field_dict = field_dict
270282
self.default = default
271283

272284
def process(self, dataset):
273-
set_dict = {name: self.default for name in dataset.get_fields().keys()}
285+
set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
274286
set_dict.update(self.field_dict)
275287
dataset.set_target(**set_dict)
276288
return dataset

fastNLP/core/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from .batch import Batch
2-
from .dataset import DataSet
2+
# from .dataset import DataSet
33
from .fieldarray import FieldArray
44
from .instance import Instance
5-
from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
5+
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
6+
from .metrics import AccuracyMetric
7+
from .optimizer import Optimizer, SGD, Adam
68
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
79
from .tester import Tester
810
from .trainer import Trainer
911
from .vocabulary import Vocabulary
10-
from .optimizer import Optimizer
11-
from .loss import Loss
12+
from ..io.dataset_loader import DataSet
13+

fastNLP/core/batch.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23

34

@@ -25,6 +26,7 @@ def __init__(self, dataset, batch_size, sampler, as_numpy=False):
2526
self.as_numpy = as_numpy
2627
self.idx_list = None
2728
self.curidx = 0
29+
self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0)
2830

2931
def __iter__(self):
3032
self.idx_list = self.sampler(self.dataset)
@@ -41,11 +43,11 @@ def __next__(self):
4143

4244
indices = self.idx_list[self.curidx:endidx]
4345

44-
for field_name, field in self.dataset.get_fields().items():
46+
for field_name, field in self.dataset.get_all_fields().items():
4547
if field.is_target or field.is_input:
4648
batch = field.get(indices)
4749
if not self.as_numpy:
48-
batch = torch.from_numpy(batch)
50+
batch = to_tensor(batch, field.dtype)
4951
if field.is_target:
5052
batch_y[field_name] = batch
5153
if field.is_input:
@@ -54,3 +56,14 @@ def __next__(self):
5456
self.curidx = endidx
5557

5658
return batch_x, batch_y
59+
60+
def __len__(self):
61+
return self.num_batches
62+
63+
64+
def to_tensor(batch, dtype):
65+
if dtype in (int, np.int8, np.int16, np.int32, np.int64):
66+
batch = torch.LongTensor(batch)
67+
if dtype in (float, np.float32, np.float64):
68+
batch = torch.FloatTensor(batch)
69+
return batch

0 commit comments

Comments
 (0)