Skip to content
This repository was archived by the owner on Mar 9, 2023. It is now read-only.

Commit 0734f9b

Browse files
authored
Merge pull request #163 from WorksApplications/feature/fix_join_katakana_oov
Fix issue #162
2 parents 9eae063 + 8c4a6ef commit 0734f9b

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

sudachipy/lattice.pyx

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -57,15 +57,15 @@ cdef class Lattice:
5757
return self.end_lists[end]
5858

5959
def get_nodes(self, begin: int, end: int) -> List[LatticeNode]:
60-
return [node for node in self.end_lists[end] if node.begin == begin]
60+
return [node for node in self.end_lists[end] if node.get_begin() == begin]
6161

62-
def get_minumum_node(self, begin: int, end: int) -> Optional[LatticeNode]:
62+
def get_minimum_node(self, begin: int, end: int) -> Optional[LatticeNode]:
6363
nodes = self.get_nodes(begin, end)
6464
if not nodes:
6565
return None
6666
min_arg = nodes[0]
6767
for node in nodes[1:]:
68-
if node.cost < min_arg.cost:
68+
if node.get_path_cost() < min_arg.get_path_cost():
6969
min_arg = node
7070
return min_arg
7171

sudachipy/plugin/path_rewrite/path_rewrite_plugin.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,12 @@ def concatenate_oov(self, path, begin, end, pos_id, lattice):
6262
raise IndexError("begin >= end")
6363
b = path[begin].get_begin()
6464
e = path[end - 1].get_end()
65+
66+
n = lattice.get_minimum_node(b, e)
67+
if n is not None:
68+
path[begin:end] = [n]
69+
return n
70+
6571
surface = ""
6672
length = 0
6773
for i in range(begin, end):
@@ -76,6 +82,7 @@ def concatenate_oov(self, path, begin, end, pos_id, lattice):
7682
node = lattice.create_node()
7783
node.set_range(b, e)
7884
node.set_word_info(wi)
85+
node.set_oov()
7986

8087
path[begin:end] = [node]
8188
return node

tests/plugin/test_join_katakana_oov_plugin.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,8 @@ def test_pos(self):
5353
path = self.get_path('アイアイウ')
5454
self.assertEqual(1, len(path))
5555
self.assertFalse(path[0].is_oov())
56+
self.assertEqual(['名詞', '固有名詞', '地名', '一般', '*', '*'],
57+
self.dict_.grammar.get_part_of_speech_string(path[0].get_word_info().pos_id))
5658

5759
def test_starts_with_middle(self):
5860
self.plugin._min_length = 3

0 commit comments

Comments
 (0)