Skip to content

Commit 62ae2cb

Browse files
committed
bug fix
2 parents 8f9c633 + e54676c commit 62ae2cb

File tree

9 files changed

+72
-31
lines changed

9 files changed

+72
-31
lines changed

fnlp-core/src/main/java/org/fnlp/data/reader/Reader.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Iterator;
2323

2424
import org.fnlp.ml.types.Instance;
25+
import org.fnlp.ml.types.InstanceSet;
2526

2627
/**
2728
* @author xpqiu
@@ -35,4 +36,16 @@ public abstract class Reader implements Iterator<Instance> {
3536
public void remove () {
3637
throw new IllegalStateException ("This Iterator<Instance> does not support remove().");
3738
}
39+
40+
41+
public InstanceSet read(){
42+
InstanceSet instSet = new InstanceSet();
43+
while (hasNext()) {
44+
Instance inst = next();
45+
if(inst!=null){
46+
instSet.add(inst);
47+
}
48+
}
49+
return instSet;
50+
}
3851
}

fnlp-core/src/main/java/org/fnlp/ml/classifier/AbstractClassifier.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.fnlp.ml.classifier.LabelParser.Type;
2525
import org.fnlp.ml.types.Instance;
2626
import org.fnlp.ml.types.InstanceSet;
27+
import org.fnlp.ml.types.alphabet.AlphabetFactory;
2728

2829
/**
2930
* 分类器抽象类
@@ -35,7 +36,12 @@
3536
public abstract class AbstractClassifier implements Serializable{
3637

3738
private static final long serialVersionUID = -175929257288466023L;
38-
39+
protected AlphabetFactory factory;
40+
41+
42+
public AlphabetFactory getAlphabetFactory() {
43+
return factory;
44+
}
3945

4046
/**
4147
* 返回分类内部结果,标签为内部表示索引,需要还原处理

fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/BayesClassifier.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
*
3232
*/
3333
public class BayesClassifier extends AbstractClassifier implements Serializable{
34-
protected AlphabetFactory factory;
34+
3535
protected ItemFrequency tf;
3636
protected Pipe pipe;
3737
protected FeatureSelect fs;

fnlp-core/src/main/java/org/fnlp/ml/classifier/linear/Linear.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public class Linear extends AbstractClassifier implements Serializable {
5252
private static final long serialVersionUID = -2626247109469506636L;
5353

5454
protected Inferencer inferencer;
55-
protected AlphabetFactory factory;
55+
5656
protected Pipe pipe;
5757

5858
public Linear(Inferencer inferencer, AlphabetFactory factory) {
@@ -126,9 +126,7 @@ public void setInferencer(Inferencer inferencer) {
126126
this.inferencer = inferencer;
127127
}
128128

129-
public AlphabetFactory getAlphabetFactory() {
130-
return factory;
131-
}
129+
132130

133131
public void setWeights(float[] weights) {
134132
inferencer.setWeights(weights);

fnlp-core/src/main/java/org/fnlp/ml/classifier/linear/OnlineTrainer.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ public class OnlineTrainer extends AbstractTrainer {
7171
protected Random random;
7272

7373
public int iternum;
74-
protected float[] weights;
74+
protected float[] weights;
75+
AlphabetFactory af;
7576

7677
public OnlineTrainer(AlphabetFactory af, int iternum) {
7778
//默认特征生成器
@@ -84,7 +85,8 @@ public OnlineTrainer(AlphabetFactory af, int iternum) {
8485
this.update = new LinearMaxPAUpdate(loss);
8586
this.iternum = iternum;
8687
this.c = 0.1f;
87-
weights = (float[]) inferencer.getWeights();
88+
this.af = af;
89+
weights = (float[]) inferencer.getWeights();
8890
if (weights == null) {
8991
weights = new float[af.getFeatureSize()];
9092
inferencer.setWeights(weights);
@@ -283,7 +285,7 @@ public Linear train(InstanceSet trainset, InstanceSet devset) {
283285
System.out.println("time escape:" + (endTime - beginTime) / 1000.0
284286
+ "s");
285287
System.out.println();
286-
Linear p = new Linear(inferencer, trainset.getAlphabetFactory());
288+
Linear p = new Linear(inferencer,af);
287289
return p;
288290
}
289291

fnlp-core/src/main/java/org/fnlp/ml/eval/Evaluation.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ public Evaluation(InstanceSet test) {
5050
for(int i=0;i<totnum;i++){
5151
golden[i] = (Integer) test.getInstance(i).getTarget();
5252
}
53-
labels = test.getAlphabetFactory().DefaultLabelAlphabet();
54-
numofclass = labels.size();
53+
5554

5655
}
5756

@@ -82,6 +81,8 @@ public void eval2File(AbstractClassifier cl,String path){
8281
* @return
8382
*/
8483
public void eval(AbstractClassifier cl,int nbest){
84+
labels = cl.getAlphabetFactory().DefaultLabelAlphabet();
85+
numofclass = labels.size();
8586
TPredict[] pred = cl.classify(test,nbest);
8687
int[] acc = new int[nbest];
8788
for(int i=0;i<totnum;i++){

fnlp-core/src/main/java/org/fnlp/nlp/parser/dep/DependencyTree.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,13 +246,13 @@ private void toList(ArrayList<List<String>> lists) {
246246
}
247247

248248
}
249-
249+
//错误
250250
public String[] getWords() {
251251
String[] words = new String[size];
252252
getWords(words);
253253
return words;
254254
}
255-
255+
//错误
256256
private void getWords(String[] words) {
257257
words[id] = word;
258258
for (int i = 0; i < leftChilds.size(); i++) {
@@ -262,6 +262,29 @@ private void getWords(String[] words) {
262262
rightChilds.get(i).getWords(words);
263263
}
264264

265+
}
266+
/**
267+
* 得到依赖类型字符串
268+
* @return
269+
*/
270+
271+
public String getTypes() {
272+
StringBuffer sb = new StringBuffer();
273+
String ste;
274+
String[] str;
275+
for (int i = 0; i < leftChilds.size(); i++) {
276+
sb.append(leftChilds.get(i).getTypes());
277+
}
278+
if(relation!=null)
279+
sb.append(relation);
280+
else
281+
sb.append("核心词");
282+
sb.append(" ");
283+
for (int i = 0; i < rightChilds.size(); i++) {
284+
sb.append(rightChilds.get(i).getTypes());
285+
}
286+
287+
return sb.toString();
265288
}
266289

267290
}

fnlp-core/src/main/java/org/fnlp/nlp/parser/dep/analysis/AnalysisTest.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,5 @@ private void print(){
9292
System.out.printf("rate(UEM):\t%.8f\ttotal(sent):\t%d\n", 1 - 1.0
9393
* errsent / sent_sum, sent_sum);
9494
}
95-
96-
97-
9895

9996
}

fnlp-demo/src/main/java/org/fnlp/demo/nlp/tc/TextClassificationCustom1.java

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,35 +59,36 @@ public static void main(String[] args) throws Exception {
5959
AlphabetFactory af = AlphabetFactory.buildFactory();
6060

6161
//使用n元特征
62-
Pipe ngrampp = new NGram(new int[] {2,3 });
62+
Pipe ngrampp = new NGram(new int[] {1});
6363
//将字符特征转换成字典索引
6464
Pipe indexpp = new StringArray2IndexArray(af);
6565
//将目标值对应的索引号作为类别
6666
Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());
6767

6868
//建立pipe组合
69-
SeriesPipes pp = new SeriesPipes(new Pipe[]{targetpp,ngrampp,indexpp});
70-
71-
InstanceSet trainset = new InstanceSet(pp,af);
72-
73-
InstanceSet testset = new InstanceSet(pp,af);
69+
SeriesPipes pp = new SeriesPipes(new Pipe[]{ngrampp,indexpp});
7470

7571
//用不同的Reader读取相应格式的文件
76-
Reader reader = new DocumentReader(trainDataPath);
72+
Reader reader = new DocumentReader(trainDataPath);
73+
InstanceSet trainset = reader.read();
7774

78-
//读入数据,并进行数据处理
79-
trainset.loadThruStagePipes(reader);
80-
// af.setStopIncrement(true);
8175
reader = new DocumentReader(testDataPath);
82-
83-
testset.loadThruStagePipes(reader);
76+
InstanceSet testset = reader.read();
77+
78+
targetpp.process(trainset);
79+
targetpp.process(testset);
80+
81+
pp.process(trainset);
82+
af.setStopIncrement(true);
83+
pp.process(testset);
84+
8485

8586

8687
/**
8788
* 建立分类器
8889
*/
89-
OnlineTrainer trainer = new OnlineTrainer(af);
90-
Linear pclassifier = trainer.train(trainset);
90+
OnlineTrainer trainer = new OnlineTrainer(af,50);
91+
Linear pclassifier = trainer.train(trainset,testset);
9192
pp.removeTargetPipe();
9293
pclassifier.setPipe(pp);
9394
af.setStopIncrement(true);
@@ -101,7 +102,7 @@ public static void main(String[] args) throws Exception {
101102

102103
//性能评测
103104
Evaluation eval = new Evaluation(testset);
104-
eval.eval(cl,1);
105+
eval.eval(cl,2);
105106

106107

107108

0 commit comments

Comments
 (0)