1
1
package main
2
2
3
3
import (
4
+ "bytes"
4
5
"encoding/json"
6
+ "flag"
7
+ "fmt"
5
8
"io/ioutil"
6
- "bytes"
7
- )
9
+ "math/rand"
10
+ "os"
11
+ "runtime"
12
+ "sync/atomic"
13
+ "time"
8
14
9
- import "os"
10
- import "sync/atomic"
11
- import "fmt"
12
- import "runtime"
13
- import "flag"
14
- import "math/rand"
15
- import "time"
15
+ "github.com/neurlang/classifier/datasets"
16
+ "github.com/neurlang/classifier/datasets/phonemizer"
17
+ "github.com/neurlang/classifier/hashtron"
18
+ "github.com/neurlang/classifier/layer/crossattention"
19
+ "github.com/neurlang/classifier/layer/sochastic"
20
+ "github.com/neurlang/classifier/layer/sum"
21
+ "github.com/neurlang/classifier/net/feedforward"
22
+ "github.com/neurlang/classifier/parallel"
23
+ "github.com/neurlang/quaternary"
24
+ )
16
25
17
- import "github.com/neurlang/classifier/datasets/phonemizer"
18
26
//import "github.com/neurlang/classifier/layer/majpool2d"
19
- import "github.com/neurlang/classifier/layer/sum"
20
- import "github.com/neurlang/classifier/layer/sochastic"
27
+
21
28
//import "github.com/neurlang/classifier/layer/parity"
22
- import "github.com/neurlang/classifier/layer/crossattention"
23
- import "github.com/neurlang/classifier/datasets"
24
- import "github.com/neurlang/classifier/hashtron"
29
+
25
30
//import "github.com/neurlang/classifier/learning"
26
- import "github.com/neurlang/quaternary"
27
- import "github.com/neurlang/classifier/net/feedforward"
28
- import "github.com/neurlang/classifier/parallel"
29
31
30
32
func error_abs (a , b uint32 ) (out uint32 ) {
31
33
xor := a ^ b
@@ -62,7 +64,7 @@ func write_histogram(langjson string, histogram []string) {
62
64
fmt .Println ("Error marshalling JSON:" , err )
63
65
return
64
66
}
65
-
67
+
66
68
updatedData = bytes .ReplaceAll (updatedData , []byte (`"],"` ), []byte ("\" ],\n \" " ))
67
69
68
70
// Step 5: Write the updated JSON back to the file
@@ -75,7 +77,7 @@ func write_histogram(langjson string, histogram []string) {
75
77
}
76
78
77
79
func main () {
78
- dirtytsv := flag .String ("dirtytsv " , "" , "dirty tsv dataset for the language" )
80
+ lexicontsv := flag .String ("lexicontsv " , "" , "lexicon tsv dataset for the language" )
79
81
learntsv := flag .String ("learntsv" , "" , "learn tsv dataset for the language" )
80
82
langjson := flag .String ("langjson" , "" , "language.json for the language to write histogram" )
81
83
premodulo := flag .Int ("premodulo" , 0 , "premodulo" )
@@ -92,7 +94,7 @@ func main() {
92
94
93
95
var improved_success_rate = 0
94
96
95
- if dirtytsv == nil || * dirtytsv == "" {
97
+ if lexicontsv == nil || * lexicontsv == "" {
96
98
println ("clean tsv is mandatory" )
97
99
return
98
100
}
@@ -106,15 +108,15 @@ func main() {
106
108
}
107
109
108
110
histogram := phonemizer .NewHistogram (* learntsv , reverse != nil && * reverse )
109
-
111
+
110
112
if langjson != nil && * langjson != "" {
111
113
write_histogram (* langjson , histogram )
112
114
}
113
-
115
+
114
116
fmt .Println (histogram )
115
117
116
- data := phonemizer .SplitAreg (phonemizer .NewDatasetAreg (* learntsv , * dirtytsv , reverse != nil && * reverse , histogram ))
117
-
118
+ data := phonemizer .SplitAreg (phonemizer .NewDatasetAreg (* learntsv , * lexicontsv , reverse != nil && * reverse , histogram ))
119
+
118
120
if len (data ) == 0 {
119
121
println ("it looks like no data for this language, or language is unambiguous (no model needed)" )
120
122
return
@@ -123,7 +125,7 @@ func main() {
123
125
const fanout1 = 16
124
126
const fanout2 = 2
125
127
const fanout3 = 3
126
-
128
+
127
129
var net feedforward.FeedforwardNetwork
128
130
net .NewLayer (fanout1 * fanout2 , 0 )
129
131
for i := 0 ; i < fanout3 ; i ++ {
@@ -134,10 +136,9 @@ func main() {
134
136
}
135
137
net .NewCombiner (sochastic .MustNew (fanout1 * fanout2 , 32 , fanout3 ))
136
138
net .NewLayer (fanout1 * fanout2 , 0 )
137
- net .NewCombiner (sum .MustNew ([]uint {fanout1 * fanout2 }, 0 ))
139
+ net .NewCombiner (sum .MustNew ([]uint {fanout1 * fanout2 }, 0 ))
138
140
net .NewLayer (1 , 0 )
139
141
140
-
141
142
trainWorst := func (worst int ) func () {
142
143
var tally = new (datasets.Tally )
143
144
tally .Init ()
@@ -148,7 +149,7 @@ func main() {
148
149
if minpremodulo != nil && * minpremodulo > 0 && maxpremodulo != nil && * maxpremodulo > 0 {
149
150
const span = 50 * 50
150
151
value := (100 - improved_success_rate ) * (100 - improved_success_rate )
151
- premodulo := value * ( * minpremodulo - * maxpremodulo ) / span + * maxpremodulo
152
+ premodulo := value * ( * minpremodulo - * maxpremodulo ) / span + * maxpremodulo
152
153
//println(improved_success_rate, premodulo)
153
154
if premodulo < 2 {
154
155
premodulo = 2
@@ -161,17 +162,17 @@ func main() {
161
162
rand .Shuffle (len (data ), func (i , j int ) { data [i ], data [j ] = data [j ], data [i ] })
162
163
parts = * part
163
164
}
164
-
165
+
165
166
parallel .ForEach (len (data )/ parts , 1000 , func (jjj int ) {
166
167
{
167
- var io = data [jjj ]
168
-
169
- io .Dimension = fanout1
168
+ var io = data [jjj ]
170
169
171
- net .Tally4 (& io , worst , tally , nil )
170
+ io .Dimension = fanout1
171
+
172
+ net .Tally4 (& io , worst , tally , nil )
172
173
}
173
174
})
174
-
175
+
175
176
if ! tally .GetImprovementPossible () {
176
177
return nil
177
178
}
@@ -194,7 +195,7 @@ func main() {
194
195
tally .Free ()
195
196
runtime .GC ()
196
197
197
- return func (){
198
+ return func () {
198
199
* ptr = backup
199
200
}
200
201
}
@@ -212,17 +213,17 @@ func main() {
212
213
io .Dimension = fanout1
213
214
214
215
var predicted = net .Infer2 (& io ) & 1
215
-
216
+
216
217
h .MustPutUint16 (j , predicted )
217
-
218
+
218
219
if predicted == io .Output () {
219
220
percent .Add (1 )
220
221
}
221
222
errsum .Add (uint64 (error_abs (uint32 (predicted ), uint32 (io .Output ()))))
222
223
}
223
224
})
224
- success := 100 * int (percent .Load ()) / (len (data )/ parts )
225
- println ("[success rate]" , success , "%" , "with" , uint64 (parts ) * errsum .Load (), "errors" )
225
+ success := 100 * int (percent .Load ()) / (len (data ) / parts )
226
+ println ("[success rate]" , success , "%" , "with" , uint64 (parts )* errsum .Load (), "errors" )
226
227
227
228
if dstmodel == nil || * dstmodel == "" {
228
229
err := net .WriteZlibWeightsToFile ("output." + fmt .Sprint (success ) + ".json.t.lzw" )
0 commit comments