33# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) 
44# Apache 2.0 
55
6+ # This program converts a transcript file `text` to labels 
7+ # used in CTC training. 
8+ # 
9+ # For example, if we have 
10+ # 
11+ # the lexicon file `lexicon.txt` 
12+ # 
13+ # foo f o o 
14+ # bar b a r 
15+ # 
16+ # the phone symbol table `tokens.txt` 
17+ # 
18+ # <eps> 0 
19+ # <blk> 1 
20+ # a 2 
21+ # b 3 
22+ # f 4 
23+ # o 5 
24+ # r 6 
25+ # 
26+ # and the transcript file `text` 
27+ # 
28+ # utt1 foo bar bar 
29+ # utt2 bar 
30+ # 
31+ # Given the above three inputs, this program generates a 
32+ # file `labels.ark` containing 
33+ # 
34+ # utt1 3 4 4 2 1 5 2 1 5 
35+ # utt2 2 1 5 
36+ # 
37+ # where 
38+ # - `3 4 4` is from `(4-1) (5-1) (5-1)`, which is from the indices of `f o o` 
39+ # - `2 1 5` is from `(3-1) (2-1) (6-1)`, which is from the indices of `b a r` 
40+ # 
# Note that 1 is subtracted here since `<eps>` exists only in FSTs
# and the neural network considers index `0` as `<blk>`. Therefore, the
# integer value of every symbol is shifted downwards by 1.
44+ 
645import  argparse 
746import  os 
847
948import  kaldi 
1049
1150
def get_args():
    '''Parse the command-line arguments for converting a transcript to labels.'''
    parser = argparse.ArgumentParser(description='''
Convert transcript to labels.

It takes the following inputs:

- lexicon.txt, the lexicon file
- tokens.txt, the phone symbol table
- dir, a directory containing the transcript file `text`

It generates `labels.scp` and `labels.ark` in the provided `dir`.

Usage:
    python3 ./local/convert_text_to_labels.py \
             --lexicon-filename data/lang/lexicon.txt \
             --tokens-filename data/lang/tokens.txt \
             --dir data/train

    It will generate data/train/labels.scp and data/train/labels.ark.
        ''')

    parser.add_argument('--lexicon-filename',
                        dest='lexicon_filename',
                        type=str,
                        help='filename for lexicon.txt')

    parser.add_argument('--tokens-filename',
                        dest='tokens_filename',
                        type=str,
                        help='filename for the phone symbol table tokens.txt')

    # NOTE(review): the original added TWO `help=` keyword arguments here
    # (a repeated keyword argument is a SyntaxError); they are merged into one.
    parser.add_argument('--dir',
                        type=str,
                        help='''input/output dir: the dir containing the
        transcript text; it will contain the generated labels.scp and
        labels.ark''')

    args = parser.parse_args()

@@ -26,14 +95,33 @@ def get_args():
2695
2796
2897def  read_lexicon (filename ):
29-     ''' 
98+     '''Read lexicon.txt and save it into a Python dict. 
99+ 
100+     Args: 
101+         filename: filename of lexicon.txt. 
102+ 
103+                   Every line in lexicon.txt has the following format: 
104+ 
105+                     word phone1 phone2 phone3 ... phoneN 
106+ 
107+                   That is, fields are separated by spaces. The first 
108+                   field is the word and the remaining fields are the 
109+                   phones indicating the pronunciation of the word. 
110+ 
30111    Returns: 
31112        a dict whose keys are words and values are phones. 
32113    ''' 
33114    lexicon  =  dict ()
115+ 
34116    with  open (filename , 'r' , encoding = 'utf-8' ) as  f :
35117        for  line  in  f :
118+             # line contains: 
119+             # word phone1 phone2 phone3 ... phoneN 
36120            word_phones  =  line .split ()
121+ 
122+             # It should have at least two fields: 
123+             # the first one is the word and 
124+             # the second one is the pronunciation 
37125            assert  len (word_phones ) >=  2 
38126
39127            word  =  word_phones [0 ]
@@ -48,23 +136,43 @@ def read_lexicon(filename):
48136
49137
50138def  read_tokens (filename ):
51-     ''' 
139+     '''Read phone symbol table tokens.txt and save it into a Python dict. 
140+ 
141+     Note that we remove the symbol `<eps>` and shift every symbol index 
142+     downwards by 1. 
143+ 
144+     Args: 
145+         filename: filename of the phone symbol table tokens.txt. 
146+ 
147+                   Two integer values have specific meanings in the symbol 
148+                   table. The first one is 0, which is reserved for `<eps>`. 
149+                   And the second one is 1, which is reserved for the 
150+                   blank symbol `<blk>`. 
151+                   Other integer values do NOT have specific meanings. 
152+ 
52153    Returns: 
53154        a dict whose keys are phones and values are phone indices 
54155    ''' 
55156    tokens  =  dict ()
56157    with  open (filename , 'r' , encoding = 'utf-8' ) as  f :
57158        for  line  in  f :
159+             # line has the format: phone index 
58160            phone_index  =  line .split ()
161+ 
162+             # it should have two fields: 
163+             # the first field is the phone 
164+             # and the second field is its index 
59165            assert  len (phone_index ) ==  2 
60166
61167            phone  =  phone_index [0 ]
62168            index  =  int (phone_index [1 ])
63169
64170            if  phone  ==  '<eps>' :
171+                 # <eps> appears only in the FSTs. 
65172                continue 
66173
67174            # decreased by one since we removed <eps> above 
175+             # and every symbol index is shifted downwards by 1 
68176            index  -=  1 
69177
70178            assert  phone  not  in   tokens 
@@ -82,27 +190,45 @@ def read_tokens(filename):
82190
83191
def read_text(filename):
    '''Read transcript file `text` and save it into a Python dict.

    Args:
        filename: filename of the transcript file `text`.

                  Every line in `text` has the following format:

                    uttid word1 word2 word3 ... wordN

                  That is, fields are separated by spaces. The first
                  field is the utterance ID and the remaining fields
                  are the words of the utterance.

    Returns:
        a dict whose keys are utterance IDs and values are lists of words
    '''
    transcript = dict()

    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            # line has the format: uttid word1 word2 word3 ... wordN
            uttid_text = line.split()

            # it should have at least 2 fields:
            # the first field is the utterance id;
            # the remaining fields are the words of the utterance
            assert len(uttid_text) >= 2

            uttid = uttid_text[0]
            text = uttid_text[1:]

            # utterance IDs must be unique within the transcript file
            assert uttid not in transcript
            transcript[uttid] = text

    return transcript
103220
104221
105222def  phones_to_indices (phone_list , tokens ):
223+     '''Convert a list of phones to a list of indices via a phone symbol table. 
224+ 
225+     Args: 
226+         phone_list: a list of phones 
227+         tokens: a dict representing a phone symbol table. 
228+ 
229+     Returns: 
230+         Return a list of indices corresponding to the given phones 
231+     ''' 
106232    index_list  =  []
107233
108234    for  phone  in  phone_list :
@@ -125,7 +251,7 @@ def main():
125251
126252    transcript_labels  =  dict ()
127253
128-     for  utt , text  in  transcript .items ():
254+     for  uttid , text  in  transcript .items ():
129255        labels  =  []
130256        for  t  in  text :
131257            # TODO(fangjun): add support for OOV. 
@@ -135,17 +261,17 @@ def main():
135261
136262            labels .extend (indices )
137263
138-         assert  utt  not  in   transcript_labels 
264+         assert  uttid  not  in   transcript_labels 
139265
140-         transcript_labels [utt ] =  labels 
266+         transcript_labels [uttid ] =  labels 
141267
142268    wspecifier  =  'ark,scp:{dir}/labels.ark,{dir}/labels.scp' .format (
143269        dir = args .dir )
144270
145271    writer  =  kaldi .IntVectorWriter (wspecifier )
146272
147-     for  utt , labels  in  transcript_labels .items ():
148-         writer .Write (utt , labels )
273+     for  uttid , labels  in  transcript_labels .items ():
274+         writer .Write (uttid , labels )
149275
150276    writer .Close ()
151277
0 commit comments