
Commit 668183b

v2.1: improve the classifier and update the models

1 parent 6623f35 commit 668183b


55 files changed: 2080 additions & 1978 deletions

.classpath

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <classpath>
-	<classpathentry kind="src" path="src/main/java"/>
 	<classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER">
 		<attributes>
 			<attribute name="maven.pomderived" value="true"/>

fnlp-core/src/main/java/org/fnlp/data/reader/SequenceReader.java

Lines changed: 14 additions & 6 deletions
@@ -99,18 +99,19 @@ private Instance readSequence() {
 		cur = null;
 		try {
 			ArrayList<ArrayList<String>> seq = new ArrayList<ArrayList<String>>();
-			ArrayList<String> first = new ArrayList(); // there is at least one column
-			seq.add(first);
+			ArrayList<String> firstColumnList = new ArrayList(); // there is at least one column
+			seq.add(firstColumnList);
 			ArrayList<String> labels = null;
 			if(hasTarget){
 				labels = new ArrayList<String>();
 			}
 			String content = null;
+
 			while ((content = reader.readLine()) != null) {
 				lineNo++;
-				// content = content.trim();
+				content = content.trim();
 				if (content.matches("^$")){
-					if(first.size()>0) // first column has > 0 elements
+					if(firstColumnList.size()>0) // first column has > 0 elements
 						break;
 					else
 						continue;
@@ -140,10 +141,17 @@ private Instance readSequence() {
 				}else{
 					ensure(colsnum,seq);
 					seq.get(colsnum).add(content.substring(start));
-				}
+				}
+				//debug
+				// if(colsnum>2){
+				//	System.out.println(content);
+				// }
 			}
-			if (first.size() > 0){
+
+			if (firstColumnList.size() > 0){
 				cur = new Instance(seq, labels);
+				//debug
+				// cur.setSource(firstColumnList.toString());
 			}
 			seq = null;
 			labels = null;
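Note: the reader still consumes column-formatted sequence data, one position per line, with a blank line terminating a sequence; the change above renames first to firstColumnList and now trims each input line. Below is a minimal, self-contained Java sketch of that parsing behavior, with simplified whitespace splitting and hypothetical names (it is not the actual SequenceReader API).

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

public class ColumnSequenceSketch {
    // Reads one column-formatted sequence: each line is one position,
    // a blank line terminates the sequence (leading blank lines are skipped).
    static List<List<String>> readOneSequence(BufferedReader reader) throws Exception {
        List<List<String>> columns = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();                      // mirrors the new content.trim()
            if (line.isEmpty()) {
                if (!columns.isEmpty()) break;       // sequence finished
                else continue;                       // skip leading blank lines
            }
            String[] cells = line.split("\\s+");     // simplified column split
            while (columns.size() < cells.length)    // grow column lists on demand
                columns.add(new ArrayList<>());
            for (int i = 0; i < cells.length; i++)
                columns.get(i).add(cells[i]);
        }
        return columns;
    }

    public static void main(String[] args) throws Exception {
        String data = "word1 TAG1\nword2 TAG2\n\nword3 TAG3\n";
        BufferedReader r = new BufferedReader(new StringReader(data));
        System.out.println(readOneSequence(r));      // [[word1, word2], [TAG1, TAG2]]
    }
}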

fnlp-core/src/main/java/org/fnlp/ml/classifier/linear/OnlineTrainer.java

Lines changed: 36 additions & 55 deletions
@@ -53,8 +53,6 @@ public class OnlineTrainer extends AbstractTrainer {
 	 */
 	public static float eps = 1e-10f;
 
-	public TrainMethod method = TrainMethod.FastAverage;
-
 	public boolean DEBUG = false;
 	public boolean shuffle = true;
 	public boolean finalOptimized = false;
@@ -75,9 +73,6 @@ public class OnlineTrainer extends AbstractTrainer {
 	public int iternum;
 	protected float[] weights;
 
-	public enum TrainMethod {
-		Perceptron, Average, FastAverage
-	}
 	public OnlineTrainer(AlphabetFactory af, int iternum) {
 		// default feature generator
 		Generator gen = new SFGenerator();
@@ -167,61 +162,60 @@ public Linear train(InstanceSet trainset, InstanceSet devset) {
 		long beginTimeIter, endTimeIter;
 		int iter = 0;
 		int frac = numSamples / 10;
-
-		float[] averageWeights = null;
-		if (method == TrainMethod.Average || method == TrainMethod.FastAverage) {
-			averageWeights = new float[weights.length];
-		}
-
+
+		// weights that the averaged perceptron subtracts at the end
+		float[] extraweight = null;
+		extraweight = new float[weights.length];
+
+
+
 		beginTime = System.currentTimeMillis();
 
-		if (shuffle)
-			trainset.shuffle(random);
+
+
+		// total number of samples visited
+		int k=0;
 
 		while (iter++ < iternum) {
 			if (!simpleOutput) {
 				System.out.print("iter "+iter+": ");
-			}
+			}
+
 			float err = 0;
 			float errtot = 0;
 			int cnt = 0;
 			int cnttot = 0;
-			int progress = frac;
+			int progress = frac;
 
+			if (shuffle)
+				trainset.shuffle(random);
+
 			beginTimeIter = System.currentTimeMillis();
-
-			float[] innerWeights = null;
-			if (method == TrainMethod.Average) {
-				innerWeights = Arrays.copyOf(weights, weights.length);
-			}
-
-			for (int ii = 0; ii < numSamples; ii++) {
+			for (int ii = 0; ii < numSamples; ii++) {
+
+				k++;
 				Instance inst = trainset.getInstance(ii);
 				Predict pred = (Predict) inferencer.getBest(inst,2);
 
 				float l = loss.calc(pred.getLabel(0), inst.getTarget());
 				if (l > 0) {
 					err += l;
					errtot++;
-					update.update(inst, weights, pred.getLabel(0), c);
+					update.update(inst, weights, k, extraweight, pred.getLabel(0), c);
 
 				}else{
 					if (pred.size() > 1)
-						update.update(inst, weights, pred.getLabel(1), c);
+						update.update(inst, weights, k, extraweight, pred.getLabel(1), c);
 				}
 				cnt += inst.length();
-				cnttot++;
-				if (method == TrainMethod.Average) {
-					for (int i = 0; i < weights.length; i++) {
-						innerWeights[i] += weights[i];
-					}
-				}
+				cnttot++;
 
 				if (!simpleOutput && progress != 0 && ii % progress == 0) {
 					System.out.print('.');
 					progress += frac;
-				}
-			}
+				}
+
+			}//end for
 
 			float curErrRate = err / cnt;
 
@@ -253,17 +247,7 @@ public Linear train(InstanceSet trainset, InstanceSet devset) {
 			if (devset != null) {
 				evaluate(devset);
 			}
-			System.out.println();
-
-			if (method == TrainMethod.Average) {
-				for (int i = 0; i < innerWeights.length; i++) {
-					averageWeights[i] += innerWeights[i] / numSamples;
-				}
-			} else if (method == TrainMethod.FastAverage) {
-				for (int i = 0; i < weights.length; i++) {
-					averageWeights[i] += weights[i];
-				}
-			}
+			System.out.println();
 
 			if (interim) {
 				Linear p = new Linear(inferencer, trainset.getAlphabetFactory());
@@ -277,18 +261,15 @@ public Linear train(InstanceSet trainset, InstanceSet devset) {
 			if(MyArrays.viarance(hisErrRate) < eps){
 				System.out.println("convergence!");
 				break;
-			}
-		}
-
-		if (method == TrainMethod.Average || method == TrainMethod.FastAverage) {
-			for (int i = 0; i < averageWeights.length; i++) {
-				averageWeights[i] /= iternum;
-			}
-			weights = null;
-			weights = averageWeights;
-			inferencer.setWeights(weights);
-		}
-
+			}
+
+		}// end while (outer loop)
+
+		// average the parameters
+		for (int i = 0; i < weights.length; i++) {
+			weights[i] -= extraweight[i]/k;
+		}
+
 		System.out.print("Non-Zero Weight Numbers: " + MyArrays.countNoneZero(weights));
 		if (finalOptimized) {
 			int[] idx = MyArrays.getTop(weights.clone(), threshold, false);
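Note: together with the AbstractPAUpdate change below, this replaces the explicit Average/FastAverage accumulation with the standard lazy trick for the averaged perceptron: every increment t applied while processing sample number k also adds t*k to extraweight, and the averaged weights are recovered at the end as weights[i] - extraweight[i]/k. A minimal, self-contained sketch (hypothetical names; an arbitrary mistake-driven update stands in for the PA step) showing that this correction equals the explicit average of the weight snapshots:

import java.util.Arrays;
import java.util.Random;

public class LazyAveragingSketch {
    public static void main(String[] args) {
        Random rnd = new Random(0);
        int dim = 4, n = 1000;
        double[] w = new double[dim];       // current weights
        double[] extra = new double[dim];   // accumulates k * update, like extraweight
        double[] sum = new double[dim];     // explicit average, kept for comparison

        for (int k = 1; k <= n; k++) {
            for (int d = 0; d < dim; d++) sum[d] += w[d];   // snapshot before sample k

            // arbitrary sparse mistake-driven update standing in for the PA step
            if (rnd.nextInt(3) == 0) {
                int d = rnd.nextInt(dim);
                double t = rnd.nextDouble() - 0.5;
                w[d] += t;
                extra[d] += t * k;          // same bookkeeping as extraweight[idx] += t * k
            }
        }

        double[] lazy = new double[dim];
        double[] explicit = new double[dim];
        for (int d = 0; d < dim; d++) {
            lazy[d] = w[d] - extra[d] / n;  // the correction applied after training
            explicit[d] = sum[d] / n;       // average of the per-sample snapshots
        }
        System.out.println(Arrays.toString(lazy));      // identical up to rounding
        System.out.println(Arrays.toString(explicit));
    }
}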

fnlp-core/src/main/java/org/fnlp/ml/classifier/linear/update/AbstractPAUpdate.java

Lines changed: 20 additions & 32 deletions
@@ -27,7 +27,6 @@
  * Abstract parameter update class, using the PA algorithm
  * \mathbf{w_{t+1}} = \w_t + {\alpha^*(\Phi(x,y)- \Phi(x,\hat{y}))}.
  * \alpha =\frac{1- \mathbf{w_t}^T \left(\Phi(x,y) - \Phi(x,\hat{y})\right)}{||\Phi(x,y) - \Phi(x,\hat{y})||^2}.
- * @author Feng Ji
  *
  */
 public abstract class AbstractPAUpdate implements Update {
@@ -52,28 +51,14 @@ public AbstractPAUpdate(Loss loss) {
 		this.loss = loss;
 	}
 
-	/**
-	 * Parameter update method
-	 * @param inst sample instance
-	 * @param weights weights
-	 * @param predict predicted answer
-	 * @param c step-size threshold
-	 * @return loss between the predicted answer and the gold answer
-	 */
-	public float update(Instance inst, float[] weights, Object predict, float c) {
-		return update(inst, weights, inst.getTarget(), predict, c);
-	}
+	@Override
+	public float update(Instance inst, float[] weights, int k, float[] extraweight, Object predict, float c) {
+		return update(inst, weights, k, extraweight, inst.getTarget(), predict, c);
+	}
+
 
-	/**
-	 * Parameter update method
-	 * @param inst sample instance
-	 * @param weights weights
-	 * @param target reference answer
-	 * @param predict predicted answer
-	 * @param c step-size threshold
-	 * @return loss between the predicted answer and the reference answer
-	 */
-	public float update(Instance inst, float[] weights, Object target,
+	@Override
+	public float update(Instance inst, float[] weights, int k, float[] extraweight, Object target,
 			Object predict, float c) {
 
 		int lost = diff(inst, weights, target, predict);
@@ -87,14 +72,17 @@ public float update(Instance inst, float[] weights, Object target,
 		alpha = alpha*inst.getWeight();
 		if(alpha>c){
 			alpha = c;
-		}else{
-			alpha=alpha;
-		}
+		}
+
 		int[] idx = diffv.indices();
 
-		for (int i = 0; i < idx.length; i++) {
-
-			weights[idx[i]] += diffv.get(idx[i]) * alpha;
+		for (int i = 0; i < idx.length; i++) {
+			float t = diffv.get(idx[i]) * alpha;
+			weights[idx[i]] += t;
+			extraweight[idx[i]] += t *k;
+		}
+		for (int i = 0; i < idx.length; i++) {
+
 		}
 	}
 
@@ -105,12 +93,12 @@ public float update(Instance inst, float[] weights, Object target,
 	}
 
 	/**
-	 * Compute the distance between the predicted answer and the reference answer
+	 * Compute the distance between the predicted label and the reference label
 	 * @param inst sample instance
 	 * @param weights weights
-	 * @param target reference answer
-	 * @param predict predicted answer
-	 * @return distance between the predicted answer and the reference answer
+	 * @param target reference label
+	 * @param predict predicted label
+	 * @return distance between the predicted label and the reference label
 	 */
 	protected abstract int diff(Instance inst, float[] weights, Object target,
 			Object predict);
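Note: the class javadoc above gives the PA step size; the update loop now also records each applied increment, scaled by the sample counter k, into extraweight for later averaging. A minimal sketch of a clipped PA-style step with that bookkeeping, using dense arrays and a generic loss value instead of FNLP's sparse vectors (hypothetical names, not the FNLP implementation; the actual numerator follows the formula in the javadoc):

import java.util.Arrays;

public class ClippedPAStepSketch {
    // diff[i] stands for Phi(x,y) - Phi(x,y_hat); loss is the margin loss for this sample.
    static void paUpdate(float[] weights, float[] extraweight,
                         float[] diff, float loss, float c, int k) {
        float norm2 = 0f;
        for (float d : diff) norm2 += d * d;      // ||Phi(x,y) - Phi(x,y_hat)||^2
        if (norm2 == 0f) return;

        float alpha = loss / norm2;               // PA step size (numerator simplified here)
        if (alpha > c) alpha = c;                 // clipped at the threshold c, as above

        for (int i = 0; i < diff.length; i++) {
            float t = diff[i] * alpha;            // increment actually applied
            weights[i] += t;
            extraweight[i] += t * k;              // bookkeeping for the averaged weights
        }
    }

    public static void main(String[] args) {
        float[] w = new float[3], extra = new float[3];
        paUpdate(w, extra, new float[]{1f, 0f, -1f}, 0.8f, 0.1f, 5);
        System.out.println(Arrays.toString(w));      // [0.1, 0.0, -0.1] (alpha clipped to c)
        System.out.println(Arrays.toString(extra));  // [0.5, 0.0, -0.5]
    }
}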

fnlp-core/src/main/java/org/fnlp/ml/classifier/linear/update/LinearMaxPAUpdate.java

Lines changed: 0 additions & 3 deletions
@@ -24,8 +24,6 @@
 
 /**
  * Parameter update class for linear classification, using the PA algorithm
- * @author Feng Ji
- *
  */
 public class LinearMaxPAUpdate extends AbstractPAUpdate {
 
@@ -58,5 +56,4 @@ protected int diff(Instance inst, float[] weights, Object target,
 		return 1;
 	}
 
-
 }

fnlp-core/src/main/java/org/fnlp/ml/classifier/linear/update/Update.java

Lines changed: 26 additions & 5 deletions
@@ -22,11 +22,32 @@
 import org.fnlp.ml.types.Instance;
 
 public interface Update {
-
-	public float update(Instance inst, float[] weights, Object predictLabel,
-			float c);
-
-	public float update(Instance inst, float[] weights, Object predictLabel,
+
+	/**
+	 *
+	 * @param inst sample instance
+	 * @param weights weights
+	 * @param k number of samples visited so far
+	 * @param extraweight weights that the averaged perceptron subtracts at the end
+	 * @param predictLabel predicted label
+	 * @param c step-size threshold
+	 * @return loss between the predicted label and the true label
+	 */
+	public float update(Instance inst, float[] weights, int k, float[] extraweight, Object predictLabel,
+			float c);
+
+	/**
+	 *
+	 * @param inst sample instance
+	 * @param weights weights
+	 * @param k number of samples visited so far
+	 * @param extraweight weights that the averaged perceptron subtracts at the end
+	 * @param predictLabel predicted label
+	 * @param goldenLabel true label
+	 * @param c step-size threshold
+	 * @return loss between the predicted label and the true label
+	 */
+	public float update(Instance inst, float[] weights, int k, float[] extraweight, Object predictLabel,
 			Object goldenLabel, float c);
 
 }
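Note: for implementers of the revised interface, the new arguments only require the bookkeeping shown in AbstractPAUpdate: apply each increment to weights and add the same increment times k to extraweight. A purely hypothetical implementation sketch follows; it assumes dense float[] features and +1/-1 Integer labels, which is not how FNLP represents instances, and exists only to show how k and extraweight are meant to be used.

import org.fnlp.ml.classifier.linear.update.Update;
import org.fnlp.ml.types.Instance;

// Hypothetical sketch, not an actual FNLP class.
public class SimplePerceptronUpdate implements Update {

    @Override
    public float update(Instance inst, float[] weights, int k, float[] extraweight,
                        Object predictLabel, float c) {
        return update(inst, weights, k, extraweight, predictLabel, inst.getTarget(), c);
    }

    @Override
    public float update(Instance inst, float[] weights, int k, float[] extraweight,
                        Object predictLabel, Object goldenLabel, float c) {
        if (predictLabel.equals(goldenLabel))
            return 0f;                                  // correct prediction: no update
        float y = ((Integer) goldenLabel).floatValue(); // assumed +1 / -1 gold label
        float[] x = (float[]) inst.getData();           // assumed dense feature vector
        for (int i = 0; i < x.length; i++) {
            float t = y * x[i];                         // perceptron increment
            weights[i] += t;
            extraweight[i] += t * k;                    // averaged-perceptron bookkeeping
        }
        return 1f;                                      // unit loss for a mistake
    }
}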
