|
1 | 1 | use std::collections::HashMap; |
2 | 2 |
|
3 | | -use crate::translator::boundaries::{word_end, word_start}; |
| 3 | +use crate::translator::boundaries::{word_end, word_start, word_number, number_word}; |
4 | 4 |
|
5 | 5 | use super::Translation; |
6 | 6 |
|
@@ -56,6 +56,12 @@ impl TrieNode { |
56 | 56 | fn not_word_end_transition(&self) -> Option<&TrieNode> { |
57 | 57 | self.transitions.get(&Transition::End(Boundary::NotWord)) |
58 | 58 | } |
| 59 | + fn word_num_transition(&self) -> Option<&TrieNode> { |
| 60 | + self.transitions.get(&Transition::End(Boundary::WordNumber)) |
| 61 | + } |
| 62 | + fn num_word_transition(&self) -> Option<&TrieNode> { |
| 63 | + self.transitions.get(&Transition::Start(Boundary::NumberWord)) |
| 64 | + } |
59 | 65 | } |
60 | 66 |
|
61 | 67 | #[derive(Default, Debug)] |
@@ -109,81 +115,74 @@ impl Trie { |
109 | 115 | } |
110 | 116 |
|
111 | 117 | fn find_translations_from_node<'a>( |
112 | | - &'a self, |
| 118 | + &self, |
113 | 119 | input: &str, |
| 120 | + prev: Option<char>, |
114 | 121 | node: &'a TrieNode, |
115 | 122 | ) -> Vec<&'a Translation> { |
116 | | - let mut current_node = node; |
117 | 123 | let mut matching_rules = Vec::new(); |
118 | | - let mut prev: Option<char> = None; |
119 | 124 | let mut chars = input.chars(); |
120 | 125 |
|
121 | | - while let Some(c) = chars.next() { |
122 | | - if let Some(node) = current_node.char_transition(c) { |
123 | | - current_node = node; |
124 | | - if let Some(ref translation) = node.translation { |
125 | | - matching_rules.push(translation) |
126 | | - } |
127 | | - } else if let Some(node) = current_node.char_case_insensitive_transition(c) { |
128 | | - current_node = node; |
129 | | - if let Some(ref translation) = node.translation { |
130 | | - matching_rules.push(translation) |
131 | | - } |
132 | | - } else if let Some(node) = current_node.word_end_transition() { |
133 | | - current_node = node; |
134 | | - if word_end(prev, Some(c)) { |
135 | | - if let Some(ref translation) = node.translation { |
136 | | - matching_rules.push(translation) |
137 | | - } |
138 | | - } |
139 | | - } else if let Some(node) = current_node.not_word_end_transition() { |
140 | | - current_node = node; |
141 | | - if !word_end(prev, Some(c)) { |
142 | | - if let Some(ref translation) = node.translation { |
143 | | - matching_rules.push(translation) |
144 | | - } |
145 | | - } |
146 | | - } else { |
147 | | - prev = Some(c); |
148 | | - break; |
| 126 | + // if this node has a translation add it to the list of matching rules |
| 127 | + if let Some(ref translation) = node.translation { |
| 128 | + matching_rules.push(translation) |
| 129 | + } |
| 130 | + let c = chars.next(); |
| 131 | + if let Some(c) = c { |
| 132 | + let bytes = c.len_utf8(); |
| 133 | + if let Some(node) = node.char_transition(c) { |
| 134 | + matching_rules.extend(self.find_translations_from_node( |
| 135 | + &input[bytes..], |
| 136 | + Some(c), |
| 137 | + node, |
| 138 | + )); |
| 139 | + } |
| 140 | + if let Some(node) = node.char_case_insensitive_transition(c) { |
| 141 | + matching_rules.extend(self.find_translations_from_node( |
| 142 | + &input[bytes..], |
| 143 | + Some(c), |
| 144 | + node, |
| 145 | + )); |
149 | 146 | } |
150 | | - prev = Some(c); |
151 | 147 | } |
152 | | - // at this point we have either |
153 | | - // - exhausted the input (chars.next() is None) or |
154 | | - // - exhausted the trie (current_node has no applicable transitions) |
155 | | - // TODO: assert this invariant (how can we do this without the |
156 | | - // side-effecting chars.next()?) |
157 | | - if let Some(node) = current_node.word_end_transition() { |
158 | | - if word_end(prev, chars.next()) { |
159 | | - if let Some(ref translation) = node.translation { |
160 | | - matching_rules.push(translation) |
161 | | - } |
| 148 | + if let Some(node) = node.word_start_transition() { |
| 149 | + if word_start(prev, c) { |
| 150 | + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); |
162 | 151 | } |
163 | | - } else if let Some(node) = current_node.not_word_end_transition() { |
164 | | - if !word_end(prev, chars.next()) { |
165 | | - if let Some(ref translation) = node.translation { |
166 | | - matching_rules.push(translation) |
167 | | - } |
| 152 | + } |
| 153 | + if let Some(node) = node.not_word_start_transition() { |
| 154 | + if !word_start(prev, c) { |
| 155 | + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); |
| 156 | + } |
| 157 | + } |
| 158 | + if let Some(node) = node.word_end_transition() { |
| 159 | + if word_end(prev, c) { |
| 160 | + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); |
| 161 | + } |
| 162 | + } |
| 163 | + if let Some(node) = node.not_word_end_transition() { |
| 164 | + if !word_end(prev, c) { |
| 165 | + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); |
| 166 | + } |
| 167 | + } |
| 168 | + if let Some(node) = node.word_num_transition() { |
| 169 | + if word_number(prev, c) { |
| 170 | + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); |
| 171 | + } |
| 172 | + } |
| 173 | + if let Some(node) = node.num_word_transition() { |
| 174 | + dbg!(prev, c); |
| 175 | + if number_word(prev, c) { |
| 176 | + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); |
168 | 177 | } |
169 | 178 | } |
170 | 179 | matching_rules |
171 | 180 | } |
172 | 181 |
|
173 | | - pub fn find_translations(&self, input: &str, before: Option<char>) -> Vec<&Translation> { |
| 182 | + pub fn find_translations(&self, input: &str, prev: Option<char>) -> Vec<&Translation> { |
174 | 183 | let mut matching_rules = Vec::new(); |
175 | 184 |
|
176 | | - if word_start(before, input.chars().next()) { |
177 | | - if let Some(node) = self.root.word_start_transition() { |
178 | | - matching_rules = self.find_translations_from_node(input, node); |
179 | | - } |
180 | | - } else { |
181 | | - if let Some(node) = self.root.not_word_start_transition() { |
182 | | - matching_rules = self.find_translations_from_node(input, node); |
183 | | - } |
184 | | - } |
185 | | - |
186 | | - matching_rules.extend(self.find_translations_from_node(input, &self.root)); |
| 185 | + matching_rules.extend(self.find_translations_from_node(input, prev, dbg!(&self.root))); |
187 | 186 | matching_rules.sort_by_key(|translation| translation.weight); |
188 | 187 | matching_rules |
189 | 188 | } |
@@ -300,6 +299,41 @@ mod tests { |
300 | 299 | assert_eq!(trie.find_translations("foobar", Some('c')), vec![&foo]); |
301 | 300 | } |
302 | 301 |
|
| 302 | + #[test] |
| 303 | + fn find_translations_with_word_num_boundary() { |
| 304 | + let mut trie = Trie::new(); |
| 305 | + let empty = Vec::<&Translation>::new(); |
| 306 | + let foo = Translation::new("aaa".into(), "A".into(), 5); |
| 307 | + trie.insert( |
| 308 | + "aaa".into(), |
| 309 | + "A".into(), |
| 310 | + Boundary::Word, |
| 311 | + Boundary::WordNumber, |
| 312 | + ); |
| 313 | + assert_eq!(trie.find_translations("aaa", None), empty); |
| 314 | + assert_eq!(trie.find_translations("aaa1", Some(' ')), vec![&foo]); |
| 315 | + assert_eq!(trie.find_translations("aaa1", Some('.')), vec![&foo]); |
| 316 | + assert_eq!(trie.find_translations("aaa1", Some('c')), empty); |
| 317 | + } |
| 318 | + |
| 319 | + #[test] |
| 320 | + fn find_translations_with_num_word_boundary() { |
| 321 | + let mut trie = Trie::new(); |
| 322 | + let empty = Vec::<&Translation>::new(); |
| 323 | + let foo = Translation::new("st".into(), "S".into(), 4); |
| 324 | + trie.insert( |
| 325 | + "st".into(), |
| 326 | + "S".into(), |
| 327 | + Boundary::NumberWord, |
| 328 | + Boundary::Word, |
| 329 | + ); |
| 330 | + assert_eq!(trie.find_translations("st", None), empty); |
| 331 | + assert_eq!(trie.find_translations("st", Some(' ')), empty); |
| 332 | + assert_eq!(trie.find_translations("st", Some('.')), empty); |
| 333 | + assert_eq!(trie.find_translations("st", Some('1')), vec![&foo]); |
| 334 | + assert_eq!(trie.find_translations("sta", Some('1')), empty); |
| 335 | + } |
| 336 | + |
303 | 337 | #[test] |
304 | 338 | fn find_translations_case_insensitive_test() { |
305 | 339 | let mut trie = Trie::new(); |
|
0 commit comments