
Commit aef1da6

bayes&knn

Deleted everything and re-added it; not sure whether there will be conflicts this time.

1 parent f047365 commit aef1da6

18 files changed (+1884 −3 lines)

fnlp-core/src/main/java/org/fnlp/ml/classifier/LabelParser.java

Lines changed: 5 additions & 1 deletion
@@ -56,7 +56,11 @@ public static Predict parse(TPredict res,
 			break;
 		case STRING:
 			pred = new Predict<String>(n);
-			for(int i=0;i<n;i++){
+			for(int i=0;i<n;i++){
+				if(res.getLabel(i)==null){
+					pred.set(i, "null", 0f);
+					continue;
+				}
 				int idx = (Integer) res.getLabel(i);
 				String l = labels.lookupString(idx);
 				pred.set(i, l, res.getScore(i));

fnlp-core/src/main/java/org/fnlp/ml/classifier/Predict.java

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ public class Predict<T> implements TPredict<T> {
 	/**
 	 * Number of predictions kept
 	 */
-	int n;
+	protected int n;
 	/**
 	 * By default only the single best value is kept
 	 */
fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/BayesClassifier.java

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
package org.fnlp.ml.classifier.bayes;

import gnu.trove.iterator.TIntFloatIterator;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.fnlp.ml.classifier.AbstractClassifier;
import org.fnlp.ml.classifier.LabelParser.Type;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.LabelParser;
import org.fnlp.ml.classifier.Predict;
import org.fnlp.ml.classifier.TPredict;
import org.fnlp.ml.feature.FeatureSelect;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.util.exception.LoadModelException;
import org.junit.Ignore;
/**
 * Naive Bayes classifier
 * @author sywu
 *
 */
public class BayesClassifier extends AbstractClassifier implements Serializable{
	protected AlphabetFactory factory;
	protected ItemFrequency tf;
	protected Pipe pipe;
	protected FeatureSelect fs;

	@Override
	public Predict classify(Instance instance, int n) {
		int typeSize=tf.getTypeSize();
		float[] score=new float[typeSize];
		Arrays.fill(score, 0.0f);

		Object obj=instance.getData();
		if(!(obj instanceof HashSparseVector)){
			System.out.println("error: input data is not a HashSparseVector!");
			return null;
		}
		HashSparseVector data = (HashSparseVector) obj;
		if(fs!=null)
			data=fs.select(data);
		TIntFloatIterator it = data.data.iterator();
		float feaSize=tf.getFeatureSize();
		while (it.hasNext()) {
			it.advance();
			if(it.key()==0)
				continue;
			int feature=it.key();
			for(int type=0;type<typeSize;type++){
				float itemF=tf.getItemFrequency(feature, type);
				float typeF=tf.getTypeFrequency(type);
				// additively smoothed log-likelihood of this feature under each class
				score[type]+=it.value()*Math.log((itemF+0.1)/(typeF+feaSize));
			}
		}

		Predict<Integer> res=new Predict<Integer>(n);
		for(int type=0;type<typeSize;type++)
			res.add(type, score[type]);

		return res;
	}

	@Override
	public Predict classify(Instance instance, Type type, int n) {
		Predict res = (Predict) classify(instance, n);
		return LabelParser.parse(res,factory.DefaultLabelAlphabet(),type);
	}
	/**
	 * Get the class label
	 * @param idx index of the class label
	 * @return
	 */
	public String getLabel(int idx) {
		return factory.DefaultLabelAlphabet().lookupString(idx);
	}

	/**
	 * Save the classifier to a file
	 * @param file
	 * @throws IOException
	 */
	public void saveTo(String file) throws IOException {
		File f = new File(file);
		File path = f.getParentFile();
		if(!path.exists()){
			path.mkdirs();
		}

		ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(
				new BufferedOutputStream(new FileOutputStream(file))));
		out.writeObject(this);
		out.close();
	}
	/**
	 * Load a classifier from a file
	 * @param file
	 * @return
	 * @throws LoadModelException
	 */
	public static BayesClassifier loadFrom(String file) throws LoadModelException{
		BayesClassifier cl = null;
		try {
			ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(
					new BufferedInputStream(new FileInputStream(file))));
			cl = (BayesClassifier) in.readObject();
			in.close();
		} catch (Exception e) {
			throw new LoadModelException(e,file);
		}
		return cl;
	}
	public void fS_CS(float percent){featureSelectionChiSquare(percent);}
	public void featureSelectionChiSquare(float percent){
		fs=new FeatureSelect(tf.getFeatureSize());
		fs.fS_CS(tf, percent);
	}
	public void fS_CS_Max(float percent){featureSelectionChiSquareMax(percent);}
	public void featureSelectionChiSquareMax(float percent){
		fs=new FeatureSelect(tf.getFeatureSize());
		fs.fS_CS_Max(tf, percent);
	}
	public void fS_IG(float percent){featureSelectionInformationGain(percent);}
	public void featureSelectionInformationGain(float percent){
		fs=new FeatureSelect(tf.getFeatureSize());
		fs.fS_IG(tf, percent);
	}
	public void noFeatureSelection(){
		fs=null;
	}
	public ItemFrequency getTf() {
		return tf;
	}

	public void setTf(ItemFrequency tf) {
		this.tf = tf;
	}
	public Pipe getPipe() {
		return pipe;
	}

	public void setPipe(Pipe pipe) {
		this.pipe = pipe;
	}

	public void setFactory(AlphabetFactory factory){
		this.factory=factory;
	}
	public AlphabetFactory getFactory(){
		return factory;
	}
}
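For reference, the scoring loop in classify() accumulates an additively smoothed multinomial log-likelihood per class. Reading directly from the code, with 0.1 as the smoothing constant, count(f, c) = tf.getItemFrequency(f, c), count(c) = tf.getTypeFrequency(c), and |F| = tf.getFeatureSize(), the score assigned to class c for a sparse input x is

$$\mathrm{score}(c) \;=\; \sum_{f\,:\,x_f \neq 0} x_f \cdot \log\frac{\mathrm{count}(f,c) + 0.1}{\mathrm{count}(c) + |F|}$$

Note that no class-prior term appears in this sum; classes are ranked by feature likelihood alone.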
fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/BayesTrainer.java

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
package org.fnlp.ml.classifier.bayes;

import gnu.trove.iterator.TIntFloatIterator;

import java.util.List;

import org.fnlp.ml.classifier.AbstractClassifier;
import org.fnlp.ml.classifier.linear.AbstractTrainer;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.nlp.pipe.SeriesPipes;
/**
 * Trainer for the naive Bayes text classification model.
 * The input training data is a sparse matrix.
 * @author sywu
 *
 */
public class BayesTrainer{

	public AbstractClassifier train(InstanceSet trainset) {
		AlphabetFactory af=trainset.getAlphabetFactory();
		SeriesPipes pp=(SeriesPipes) trainset.getPipes();
		pp.removeTargetPipe();
		return train(trainset,af,pp);
	}
	public AbstractClassifier train(InstanceSet trainset,AlphabetFactory af,Pipe pp) {
		// counting item/class frequencies is the entire training step
		ItemFrequency tf=new ItemFrequency(trainset,af);
		BayesClassifier classifier=new BayesClassifier();
		classifier.setFactory(af);
		classifier.setPipe(pp);
		classifier.setTf(tf);
		return classifier;
	}
}
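A minimal usage sketch tying the two classes together. This is an assumed workflow, not part of the commit: it presumes an InstanceSet whose instances already carry HashSparseVector data and whose pipes are a SeriesPipes, as train() above requires, and that Predict.getLabel(int) returns the parsed label as used in LabelParser. The class name, method name, and the 0.2f selection ratio are hypothetical.

package org.fnlp.ml.classifier.bayes;

import org.fnlp.ml.classifier.LabelParser.Type;
import org.fnlp.ml.classifier.Predict;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;

// Hypothetical helper class, not part of the commit.
public class BayesUsageSketch {
	// Train a BayesClassifier on a prepared training set and label a single instance.
	public static String trainAndLabel(InstanceSet trainset, Instance inst) {
		BayesClassifier classifier = (BayesClassifier) new BayesTrainer().train(trainset);
		classifier.fS_IG(0.2f);	// optional: keep 20% of the features by information gain
		Predict<String> pred = (Predict<String>) classifier.classify(inst, Type.STRING, 1);
		return pred.getLabel(0);	// best-scoring label name
	}
}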
fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/Heap.java

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
package org.fnlp.ml.classifier.bayes;

import java.util.ArrayList;
/**
 * Heap (bounded; min-rooted or max-rooted)
 * @author sywu
 *
 * @param <T> type of the stored data
 */
public class Heap<T>{
	private boolean isMinRootHeap;
	private ArrayList<T> datas;
	private double[] scores;
	private int maxsize;
	private int size;

	public Heap(int max,boolean isMinRootHeap) {
		this.isMinRootHeap=isMinRootHeap;
		maxsize = max;
		scores = new double[maxsize+1];
		datas= new ArrayList<T>();
		size = 0;
		// index 0 is unused so that the children of node i live at 2*i and 2*i+1
		datas.add(null);
		scores[0]=0;

	}
	public Heap(int max) {
		this(max,true);
	}


	private int leftchild(int pos) {
		return 2 * pos;
	}

	private int rightchild(int pos) {
		return 2 * pos + 1;
	}

	private int parent(int pos) {
		return pos / 2;
	}

	private boolean isleaf(int pos) {
		return ((pos > size / 2) && (pos <= size));
	}

	private boolean needSwapWithParent(int pos){
		return isMinRootHeap?
				scores[pos] < scores[parent(pos)]:
				scores[pos] > scores[parent(pos)];
	}

	private void swap(int pos1, int pos2) {
		double tmp;
		tmp = scores[pos1];
		scores[pos1] = scores[pos2];
		scores[pos2] = tmp;
		T t1,t2;
		t1=datas.get(pos1);
		t2=datas.get(pos2);
		datas.set(pos1, t2);
		datas.set(pos2, t1);
	}


	public void insert(double score,T data) {
		if(size<maxsize){
			size++;
			scores[size] = score;
			datas.add(data);
			int current = size;
			while (current!=1&&needSwapWithParent(current)) {
				swap(current, parent(current));
				current = parent(current);
			}
		}
		else {
			// heap is full: replace the root only if the new item beats it
			if(isMinRootHeap?
					score>scores[1]:
					score<scores[1]){
				scores[1]=score;
				datas.set(1, data);
				pushdown(1);
			}
		}
	}


	public void print() {
		int i;
		for (i = 1; i <= size; i++)
			System.out.println(scores[i] + " " +datas.get(i).toString());
		System.out.println();
	}


//	public int removemin() {
//		swap(1, size);
//		size--;
//		if (size != 0)
//			pushdown(1);
//		return score[size + 1];
//	}
	private int findcheckchild(int pos){
		int rlt;
		rlt = leftchild(pos);
		if(rlt==size)
			return rlt;
		if (isMinRootHeap?(scores[rlt] > scores[rlt + 1]):(scores[rlt] < scores[rlt + 1]))
			rlt = rlt + 1;
		return rlt;
	}

	private void pushdown(int pos) {
		int checkchild;
		while (!isleaf(pos)) {
			checkchild = findcheckchild(pos);
			if(needSwapWithParent(checkchild))
				swap(pos, checkchild);
			else
				return;
			pos = checkchild;
		}
	}

	public ArrayList<T> getData(){
		return datas;
	}

	public static void main(String args[])
	{
		Heap<String> hm = new Heap<String>(6,true);
		hm.insert(1,"11");
		hm.insert(4,"44");
		hm.insert(2,"22");
		hm.insert(6,"66");
		hm.insert(3,"33");
		hm.insert(5,"55");
		hm.insert(9,"99");
		hm.insert(7,"77");
		hm.print();

	}
}
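Because insert() replaces the root only when the new score beats it, a min-root Heap of capacity k retains the k highest-scoring items seen so far (and a max-root heap the k lowest), which is the usual top-k selection step of the knn classifier mentioned in the commit title. A small sketch under that assumption; the class name and the data below are hypothetical, not part of the commit.

package org.fnlp.ml.classifier.bayes;

// Hypothetical example: use the bounded Heap above to keep the 3 highest-scoring items.
public class HeapTopKSketch {
	public static void main(String[] args) {
		Heap<String> topK = new Heap<String>(3, true);	// min-root heap => retains the largest scores
		double[] scores = {0.2, 0.9, 0.4, 0.7, 0.1};
		String[] ids    = {"a", "b", "c", "d", "e"};
		for (int i = 0; i < scores.length; i++)
			topK.insert(scores[i], ids[i]);
		// getData() exposes the backing list; index 0 is a placeholder, indices 1..k hold the kept items.
		for (String id : topK.getData().subList(1, 4))
			System.out.println(id);	// prints c, b, d — the three highest-scoring ids, in heap order
	}
}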
