1
+ package org .ansj .app .crf ;
2
+
3
+ import java .io .BufferedInputStream ;
4
+ import java .io .BufferedOutputStream ;
5
+ import java .io .FileInputStream ;
6
+ import java .io .FileNotFoundException ;
7
+ import java .io .FileOutputStream ;
8
+ import java .io .IOException ;
9
+ import java .io .InputStream ;
10
+ import java .io .ObjectInputStream ;
11
+ import java .io .ObjectOutputStream ;
12
+ import java .util .ArrayList ;
13
+ import java .util .List ;
14
+ import java .util .Map ;
15
+ import java .util .Map .Entry ;
16
+ import java .util .zip .GZIPInputStream ;
17
+ import java .util .zip .GZIPOutputStream ;
18
+
19
+ import love .cq .domain .SmartForest ;
20
+
21
+ import org .ansj .app .crf .pojo .Element ;
22
+ import org .ansj .app .crf .pojo .Feature ;
23
+ import org .ansj .app .crf .pojo .Template ;
24
+
25
+ public abstract class Model {
26
+
27
+ public static enum MODEL_TYPE {
28
+ CRF , EMM
29
+ };
30
+
31
+ protected Template template = null ;
32
+
33
+ protected double [][] status = null ;
34
+
35
+ protected Map <String , Feature > myGrad ;
36
+
37
+ protected SmartForest <double [][]> smartForest = null ;
38
+
39
+ public int allFeatureCount = 0 ;
40
+
41
+ private List <Element > leftList = null ;
42
+
43
+ private List <Element > rightList = null ;
44
+
45
+ public int end1 ;
46
+
47
+ public int end2 ;
48
+
49
+ /**
50
+ * 根据模板文件解析特征
51
+ *
52
+ * @param template
53
+ * @throws IOException
54
+ */
55
+ private void makeSide (int left , int right ) throws IOException {
56
+ // TODO Auto-generated method stub
57
+
58
+ leftList = new ArrayList <Element >(Math .abs (left ));
59
+ for (int i = left ; i < 0 ; i ++) {
60
+ leftList .add (new Element ((char ) ('B' + i )));
61
+ }
62
+
63
+ rightList = new ArrayList <Element >(right );
64
+ for (int i = 1 ; i < right + 1 ; i ++) {
65
+ rightList .add (new Element ((char ) ('B' + i )));
66
+ }
67
+ }
68
+
69
+ /**
70
+ * 讲模型写入
71
+ *
72
+ * @param path
73
+ * @throws FileNotFoundException
74
+ * @throws IOException
75
+ */
76
+ public void writeModel (String path ) throws FileNotFoundException , IOException {
77
+ // TODO Auto-generated method stub
78
+
79
+ System .out .println ("compute ok now to save model!" );
80
+ // 写模型
81
+ ObjectOutputStream oos = new ObjectOutputStream (new BufferedOutputStream (new GZIPOutputStream (new FileOutputStream (path ))));
82
+
83
+ // 配置模板
84
+ oos .writeObject (template );
85
+ // 特征转移率
86
+ oos .writeObject (status );
87
+ // 总共的特征数
88
+ oos .writeInt (myGrad .size ());
89
+ double [] ds = null ;
90
+ for (Entry <String , Feature > entry : myGrad .entrySet ()) {
91
+ oos .writeUTF (entry .getKey ());
92
+ for (int i = 0 ; i < template .ft .length ; i ++) {
93
+ ds = entry .getValue ().w [i ];
94
+ for (int j = 0 ; j < ds .length ; j ++) {
95
+ oos .writeByte (j );
96
+ oos .writeFloat ((float ) ds [j ]);
97
+ }
98
+ oos .writeByte (-1 );
99
+ }
100
+ }
101
+
102
+ oos .flush ();
103
+ oos .close ();
104
+
105
+ }
106
+
107
+ /**
108
+ * 模型读取
109
+ *
110
+ * @param path
111
+ * @return
112
+ * @return
113
+ * @throws FileNotFoundException
114
+ * @throws IOException
115
+ * @throws ClassNotFoundException
116
+ */
117
+ public static Model loadModel (String modelPath ) throws Exception {
118
+ return loadModel (new FileInputStream (modelPath ));
119
+
120
+ }
121
+
122
+ public static Model loadModel (InputStream modelStream ) throws Exception {
123
+ ObjectInputStream ois = null ;
124
+ try {
125
+ ois = new ObjectInputStream (new BufferedInputStream (new GZIPInputStream (modelStream )));
126
+
127
+ Model model = new Model () {
128
+
129
+ @ Override
130
+ public void writeModel (String path ) throws FileNotFoundException , IOException {
131
+ // TODO Auto-generated method stub
132
+ throw new RuntimeException ("you can not to calculate ,this model only use by cut " );
133
+ }
134
+
135
+ };
136
+
137
+ model .template = (Template ) ois .readObject ();
138
+
139
+ model .makeSide (model .template .left , model .template .right );
140
+
141
+ int tagNum = model .template .tagNum ;
142
+
143
+ int featureNum = model .template .ft .length ;
144
+
145
+ model .smartForest = new SmartForest <double [][]>(0.8 );
146
+
147
+ model .status = (double [][]) ois .readObject ();
148
+
149
+ // 总共的特征数
150
+ double [][] w = null ;
151
+ String key = null ;
152
+ int b = 0 ;
153
+ int featureCount = ois .readInt ();
154
+ for (int i = 0 ; i < featureCount ; i ++) {
155
+ key = ois .readUTF ();
156
+ w = new double [featureNum ][0 ];
157
+ for (int j = 0 ; j < featureNum ; j ++) {
158
+ while ((b = ois .readByte ()) != -1 ) {
159
+ if (w [j ].length == 0 ) {
160
+ w [j ] = new double [tagNum ];
161
+ }
162
+ w [j ][b ] = ois .readFloat ();
163
+ }
164
+ }
165
+ model .smartForest .add (key , w );
166
+ }
167
+
168
+ return model ;
169
+ } finally {
170
+ if (ois != null ) {
171
+ ois .close ();
172
+ }
173
+ }
174
+ }
175
+
176
+ public double [] getFeature (int featureIndex , char ... chars ) {
177
+ // TODO Auto-generated method stub
178
+ SmartForest <double [][]> sf = smartForest ;
179
+ sf = sf .getBranch (chars );
180
+ if (sf == null || sf .getParam () == null ) {
181
+ return null ;
182
+ }
183
+ return sf .getParam ()[featureIndex ];
184
+ }
185
+
186
+ /**
187
+ * tag转移率
188
+ *
189
+ * @param s1
190
+ * @param s2
191
+ * @return
192
+ */
193
+ public double tagRate (int s1 , int s2 ) {
194
+ // TODO Auto-generated method stub
195
+ return status [s1 ][s2 ];
196
+ }
197
+
198
+ }
0 commit comments