Skip to content

Commit 9474c09

Browse files
authored
Merge pull request #2 from thewh1teagle/feat/rename-dirty-to-lexicon
rename dirty to lexicon
2 parents d53749e + b6bc6f1 commit 9474c09

File tree

2 files changed

+64
-65
lines changed

2 files changed

+64
-65
lines changed

cmd/train_phonemizer2/main.go

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,33 @@
11
package main
22

33
import (
4+
"bytes"
45
"encoding/json"
6+
"flag"
7+
"fmt"
58
"io/ioutil"
6-
"bytes"
7-
)
9+
"math/rand"
10+
"os"
11+
"runtime"
12+
"sync/atomic"
13+
"time"
814

9-
import "os"
10-
import "sync/atomic"
11-
import "fmt"
12-
import "runtime"
13-
import "flag"
14-
import "math/rand"
15-
import "time"
15+
"github.com/neurlang/classifier/datasets"
16+
"github.com/neurlang/classifier/datasets/phonemizer"
17+
"github.com/neurlang/classifier/hashtron"
18+
"github.com/neurlang/classifier/layer/crossattention"
19+
"github.com/neurlang/classifier/layer/sochastic"
20+
"github.com/neurlang/classifier/layer/sum"
21+
"github.com/neurlang/classifier/net/feedforward"
22+
"github.com/neurlang/classifier/parallel"
23+
"github.com/neurlang/quaternary"
24+
)
1625

17-
import "github.com/neurlang/classifier/datasets/phonemizer"
1826
//import "github.com/neurlang/classifier/layer/majpool2d"
19-
import "github.com/neurlang/classifier/layer/sum"
20-
import "github.com/neurlang/classifier/layer/sochastic"
27+
2128
//import "github.com/neurlang/classifier/layer/parity"
22-
import "github.com/neurlang/classifier/layer/crossattention"
23-
import "github.com/neurlang/classifier/datasets"
24-
import "github.com/neurlang/classifier/hashtron"
29+
2530
//import "github.com/neurlang/classifier/learning"
26-
import "github.com/neurlang/quaternary"
27-
import "github.com/neurlang/classifier/net/feedforward"
28-
import "github.com/neurlang/classifier/parallel"
2931

3032
func error_abs(a, b uint32) (out uint32) {
3133
xor := a ^ b
@@ -62,7 +64,7 @@ func write_histogram(langjson string, histogram []string) {
6264
fmt.Println("Error marshalling JSON:", err)
6365
return
6466
}
65-
67+
6668
updatedData = bytes.ReplaceAll(updatedData, []byte(`"],"`), []byte("\"],\n\""))
6769

6870
// Step 5: Write the updated JSON back to the file
@@ -75,7 +77,7 @@ func write_histogram(langjson string, histogram []string) {
7577
}
7678

7779
func main() {
78-
dirtytsv := flag.String("dirtytsv", "", "dirty tsv dataset for the language")
80+
lexicontsv := flag.String("lexicontsv", "", "lexicon tsv dataset for the language")
7981
learntsv := flag.String("learntsv", "", "learn tsv dataset for the language")
8082
langjson := flag.String("langjson", "", "language.json for the language to write histogram")
8183
premodulo := flag.Int("premodulo", 0, "premodulo")
@@ -92,7 +94,7 @@ func main() {
9294

9395
var improved_success_rate = 0
9496

95-
if dirtytsv == nil || *dirtytsv == "" {
97+
if lexicontsv == nil || *lexicontsv == "" {
9698
println("clean tsv is mandatory")
9799
return
98100
}
@@ -106,15 +108,15 @@ func main() {
106108
}
107109

108110
histogram := phonemizer.NewHistogram(*learntsv, reverse != nil && *reverse)
109-
111+
110112
if langjson != nil && *langjson != "" {
111113
write_histogram(*langjson, histogram)
112114
}
113-
115+
114116
fmt.Println(histogram)
115117

116-
data := phonemizer.SplitAreg(phonemizer.NewDatasetAreg(*learntsv, *dirtytsv, reverse != nil && *reverse, histogram))
117-
118+
data := phonemizer.SplitAreg(phonemizer.NewDatasetAreg(*learntsv, *lexicontsv, reverse != nil && *reverse, histogram))
119+
118120
if len(data) == 0 {
119121
println("it looks like no data for this language, or language is unambiguous (no model needed)")
120122
return
@@ -123,7 +125,7 @@ func main() {
123125
const fanout1 = 16
124126
const fanout2 = 2
125127
const fanout3 = 3
126-
128+
127129
var net feedforward.FeedforwardNetwork
128130
net.NewLayer(fanout1*fanout2, 0)
129131
for i := 0; i < fanout3; i++ {
@@ -134,10 +136,9 @@ func main() {
134136
}
135137
net.NewCombiner(sochastic.MustNew(fanout1*fanout2, 32, fanout3))
136138
net.NewLayer(fanout1*fanout2, 0)
137-
net.NewCombiner(sum.MustNew([]uint{fanout1*fanout2}, 0))
139+
net.NewCombiner(sum.MustNew([]uint{fanout1 * fanout2}, 0))
138140
net.NewLayer(1, 0)
139141

140-
141142
trainWorst := func(worst int) func() {
142143
var tally = new(datasets.Tally)
143144
tally.Init()
@@ -148,7 +149,7 @@ func main() {
148149
if minpremodulo != nil && *minpremodulo > 0 && maxpremodulo != nil && *maxpremodulo > 0 {
149150
const span = 50 * 50
150151
value := (100 - improved_success_rate) * (100 - improved_success_rate)
151-
premodulo := value * ( *minpremodulo - *maxpremodulo ) / span + *maxpremodulo
152+
premodulo := value*(*minpremodulo-*maxpremodulo)/span + *maxpremodulo
152153
//println(improved_success_rate, premodulo)
153154
if premodulo < 2 {
154155
premodulo = 2
@@ -161,17 +162,17 @@ func main() {
161162
rand.Shuffle(len(data), func(i, j int) { data[i], data[j] = data[j], data[i] })
162163
parts = *part
163164
}
164-
165+
165166
parallel.ForEach(len(data)/parts, 1000, func(jjj int) {
166167
{
167-
var io = data[jjj]
168-
169-
io.Dimension = fanout1
168+
var io = data[jjj]
170169

171-
net.Tally4(&io, worst, tally, nil)
170+
io.Dimension = fanout1
171+
172+
net.Tally4(&io, worst, tally, nil)
172173
}
173174
})
174-
175+
175176
if !tally.GetImprovementPossible() {
176177
return nil
177178
}
@@ -194,7 +195,7 @@ func main() {
194195
tally.Free()
195196
runtime.GC()
196197

197-
return func(){
198+
return func() {
198199
*ptr = backup
199200
}
200201
}
@@ -212,17 +213,17 @@ func main() {
212213
io.Dimension = fanout1
213214

214215
var predicted = net.Infer2(&io) & 1
215-
216+
216217
h.MustPutUint16(j, predicted)
217-
218+
218219
if predicted == io.Output() {
219220
percent.Add(1)
220221
}
221222
errsum.Add(uint64(error_abs(uint32(predicted), uint32(io.Output()))))
222223
}
223224
})
224-
success := 100 * int(percent.Load()) / (len(data)/parts)
225-
println("[success rate]", success, "%", "with", uint64(parts) * errsum.Load(), "errors")
225+
success := 100 * int(percent.Load()) / (len(data) / parts)
226+
println("[success rate]", success, "%", "with", uint64(parts)*errsum.Load(), "errors")
226227

227228
if dstmodel == nil || *dstmodel == "" {
228229
err := net.WriteZlibWeightsToFile("output." + fmt.Sprint(success) + ".json.t.lzw")

datasets/phonemizer_multi/multi.go

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
package phonemizer_multi
22

3-
import "github.com/jbarham/primegen"
4-
import (
5-
"github.com/neurlang/classifier/hash"
6-
"encoding/json"
7-
"sort"
8-
"strconv"
9-
)
103
import (
114
"bufio"
5+
"encoding/json"
126
"fmt"
137
"os"
8+
"sort"
9+
"strconv"
1410
"strings"
15-
//"encoding/json"
11+
12+
"github.com/jbarham/primegen"
13+
"github.com/neurlang/classifier/hash"
1614
)
1715

16+
//"encoding/json"
17+
1818
var Primes []uint32
1919

2020
func init() {
@@ -45,15 +45,15 @@ func (t *Token) Len() int {
4545

4646
func (s *Sample) V1(dim, pos int) SampleSentence {
4747
return SampleSentence{
48-
Sample: s,
49-
position: pos,
48+
Sample: s,
49+
position: pos,
5050
dimension: dim,
5151
}
5252
}
5353

5454
type SampleSentence struct {
55-
Sample *Sample
56-
position int
55+
Sample *Sample
56+
position int
5757
dimension int
5858
}
5959

@@ -66,13 +66,13 @@ func (s *SampleSentence) Len() int {
6666

6767
type SampleSentenceIO struct {
6868
SampleSentence *SampleSentence
69-
choice int
69+
choice int
7070
}
7171

7272
func (s *SampleSentence) IO(n int) (ret *SampleSentenceIO) {
7373
return &SampleSentenceIO{
7474
SampleSentence: s,
75-
choice: n,
75+
choice: n,
7676
}
7777
}
7878

@@ -85,14 +85,14 @@ func (s *SampleSentenceIO) Feature(n int) (ret uint32) {
8585
if s.Parity() == 1 {
8686
ret = 1 << 31
8787
}
88-
if n % 3 == 0 {
89-
for ; pos < len((s.SampleSentence.Sample.Sentence)); pos += (s.SampleSentence.dimension/3) {
88+
if n%3 == 0 {
89+
for ; pos < len((s.SampleSentence.Sample.Sentence)); pos += (s.SampleSentence.dimension / 3) {
9090
ret += uint32(s.SampleSentence.Sample.Sentence[pos].Homograph)
9191
}
9292
return
9393

9494
}
95-
for ; pos < len((s.SampleSentence.Sample.Sentence)); pos += (s.SampleSentence.dimension/3) {
95+
for ; pos < len((s.SampleSentence.Sample.Sentence)); pos += (s.SampleSentence.dimension / 3) {
9696
if pos < s.SampleSentence.position {
9797
ret += uint32(s.SampleSentence.Sample.Sentence[pos].Solution)
9898
} else if pos == s.SampleSentence.position {
@@ -113,7 +113,7 @@ func (s *SampleSentenceIO) Parity() (ret uint16) {
113113
return uint16(len(s.SampleSentence.Sample.Sentence) & 1)
114114
}
115115
func (s *SampleSentenceIO) Output() (ret uint16) {
116-
if (s.SampleSentence.Sample.Sentence[s.SampleSentence.position].Choices[s.choice][0] == s.SampleSentence.Sample.Sentence[s.SampleSentence.position].Solution) {
116+
if s.SampleSentence.Sample.Sentence[s.SampleSentence.position].Choices[s.choice][0] == s.SampleSentence.Sample.Sentence[s.SampleSentence.position].Solution {
117117
return 1
118118
}
119119
return 0
@@ -159,7 +159,6 @@ func loop(filename string, do func(string, string, string)) {
159159
}
160160
}
161161

162-
163162
func addTags(bag map[uint32]string, tags ...string) map[uint32]string {
164163
for _, v := range tags {
165164
bag[hash.StringHash(0, v)] = v
@@ -202,13 +201,12 @@ func serializeTags(tags map[uint32]string) (key uint32, ret string) {
202201
return
203202
}
204203

205-
206204
func NewDataset(dir string) (ret []Sample) {
207205

208206
var tags = make(map[uint32]string)
209207
var m = make(map[string]map[string]uint32)
210208

211-
loop(dir + string(os.PathSeparator) + "dirty.tsv", func(src string, dst, tag string) {
209+
loop(dir+string(os.PathSeparator)+"lexicon.tsv", func(src string, dst, tag string) {
212210
if _, ok := m[src]; !ok {
213211
m[src] = make(map[string]uint32)
214212
}
@@ -232,7 +230,7 @@ func NewDataset(dir string) (ret []Sample) {
232230
}
233231
})
234232

235-
loop(dir + string(os.PathSeparator) + "multi.tsv", func(src string, dst, _ string) {
233+
loop(dir+string(os.PathSeparator)+"multi.tsv", func(src string, dst, _ string) {
236234
srcv := strings.Split(src, " ")
237235
dstv := strings.Split(dst, " ")
238236
if len(srcv) != len(dstv) {
@@ -249,7 +247,7 @@ func NewDataset(dir string) (ret []Sample) {
249247
fmt.Println("ERROR: Word not in dict:", srcv[i], dstv[i])
250248
t := Token{
251249
Homograph: hash.StringHash(0, srcv[i]),
252-
Solution: 0,
250+
Solution: 0,
253251
}
254252
s.Sentence = append(s.Sentence, t)
255253
continue
@@ -284,8 +282,8 @@ func NewDataset(dir string) (ret []Sample) {
284282
}
285283
t := Token{
286284
Homograph: hash.StringHash(0, srcv[i]),
287-
Solution: sol,
288-
Choices: array,
285+
Solution: sol,
286+
Choices: array,
289287
}
290288
s.Sentence = append(s.Sentence, t)
291289
}

0 commit comments

Comments
 (0)