diff --git a/pydatastructs/strings/tests/test_trie.py b/pydatastructs/strings/tests/test_trie.py index 059104708..10667e249 100644 --- a/pydatastructs/strings/tests/test_trie.py +++ b/pydatastructs/strings/tests/test_trie.py @@ -47,3 +47,42 @@ def test_Trie(): for j in range(i + 1): assert trie_1.is_inserted(prefix_strings_1[j]) assert trie_1.is_present(prefix_strings_1[j]) + + assert trie_1.count_words() == 3 + + assert trie_1.longest_common_prefix() == "dict" + + assert trie_1.autocomplete("dict") == ["dict", "dicts", "dicts_lists_tuples"] + + trie_2 = Trie() + trie_2.insert("apple") + trie_2.insert("app") + trie_2.insert("apricot") + trie_2.insert("banana") + assert trie_2.count_words() == 4 + + trie_2.clear() + assert trie_2.count_words() == 0 + + assert trie_2.is_empty() + + trie_3 = Trie() + trie_3.insert("hello") + trie_3.insert("world") + assert sorted(trie_3.all_words()) == ["hello", "world"] + + trie_4 = Trie() + trie_4.insert("zebra") + trie_4.insert("dog") + trie_4.insert("duck") + trie_4.insert("dove") + assert trie_4.shortest_unique_prefix() == { + "zebra": "z", + "dog": "dog", + "duck": "du", + "dove": "dov" + } + assert trie_4.starts_with("do") + assert not trie_4.starts_with("cat") + + assert trie_4.longest_word() == "zebra" diff --git a/pydatastructs/strings/trie.py b/pydatastructs/strings/trie.py index cdf6666cf..7e06746d1 100644 --- a/pydatastructs/strings/trie.py +++ b/pydatastructs/strings/trie.py @@ -49,7 +49,10 @@ class Trie(object): @classmethod def methods(cls): return ['__new__', 'insert', 'is_present', 'delete', - 'strings_with_prefix'] + 'strings_with_prefix', 'count_words', 'longest_common_prefix', + 'autocomplete', 'bulk_insert', 'clear', 'is_empty', + 'all_words', 'shortest_unique_prefix', 'starts_with', + 'longest_word'] def __new__(cls, **kwargs): raise_if_backend_is_not_python( @@ -176,26 +179,177 @@ def strings_with_prefix(self, string: str) -> list: The list of strings with the given prefix. """ - def _collect(prefix: str, node: TrieNode, strings: list) -> str: - TrieNode_stack = Stack() - TrieNode_stack.append((node, prefix)) - while TrieNode_stack: - walk, curr_prefix = TrieNode_stack.pop() - if walk.is_terminal: - strings.append(curr_prefix + walk.char) - for child in walk._children: - TrieNode_stack.append((walk.get_child(child), curr_prefix + walk.char)) + def _collect(node: TrieNode, prefix: str, strings: list): + stack = [(node, prefix)] + while stack: + current_node, current_prefix = stack.pop() + if current_node.is_terminal: + strings.append(current_prefix) + for child in current_node._children: + stack.append((current_node.get_child(child), current_prefix + child)) strings = [] - prefix = "" walk = self.root for char in string: - walk = walk.get_child(char) - if walk is None: + if walk.get_child(char) is None: return strings - prefix += char - if walk.is_terminal: - strings.append(walk.char) - for child in walk._children: - _collect(prefix, walk.get_child(child), strings) + walk = walk.get_child(char) + _collect(walk, string, strings) return strings + + def count_words(self) -> int: + """ + Returns the total number of words inserted into the trie. + + Returns + ======= + + count: int + The total number of words in the trie. + """ + def _count(node: TrieNode) -> int: + count = 0 + if node.is_terminal: + count += 1 + for child in node._children: + count += _count(node.get_child(child)) + return count + + return _count(self.root) + + def longest_common_prefix(self) -> str: + """ + Finds the longest common prefix among all the words in the trie. + + Returns + ======= + + prefix: str + The longest common prefix. + """ + prefix = "" + walk = self.root + while len(walk._children) == 1 and not walk.is_terminal: + char = next(iter(walk._children)) + prefix += char + walk = walk.get_child(char) + return prefix + + def autocomplete(self, prefix: str) -> list: + """ + Provides autocomplete suggestions based on the given prefix. + + Parameters + ========== + + prefix: str + + Returns + ======= + + suggestions: list + A list of autocomplete suggestions. + """ + return self.strings_with_prefix(prefix) + + def clear(self) -> None: + """ + Resets the Trie by replacing the root node with a new empty node. + This will effectively clears all words from the Trie. + + Returns + ======= + + None + """ + self.root = TrieNode() + + def is_empty(self) -> bool: + """ + Checks if the trie is empty. + + Returns + ======= + + bool + True if the trie is empty, False otherwise. + """ + return not self.root._children + + def all_words(self) -> list: + """ + Retrieves all words stored in the trie. + + Returns + ======= + + words: list + A list of all words in the trie. + """ + return self.strings_with_prefix("") + + def shortest_unique_prefix(self) -> dict: + """ + Finds the shortest unique prefix for each word in the trie. + + Returns + ======= + prefixes: dict + A dictionary where keys are words and values are their shortest unique prefixes. + """ + def _find_prefix(node: TrieNode, prefix: str, prefixes: dict, word: str = ""): + if node.is_terminal: + prefixes[word] = prefix # Store full word as key + for child in node._children: + new_word = word + child # Build full word + new_prefix = prefix + child + if len(node._children) > 1 or node.is_terminal: + _find_prefix(node.get_child(child), new_prefix, prefixes, new_word) + else: + _find_prefix(node.get_child(child), prefix, prefixes, new_word) + + prefixes = {} + _find_prefix(self.root, "", prefixes) + return prefixes + + + def starts_with(self, prefix: str) -> bool: + """ + Checks if any word in the trie starts with the given prefix. + + Parameters + ========== + + prefix: str + + Returns + ======= + + bool + True if any word starts with the prefix, False otherwise. + """ + walk = self.root + for char in prefix: + if walk.get_child(char) is None: + return False + walk = walk.get_child(char) + return True + + def longest_word(self) -> str: + """ + Finds the longest word stored in the trie. + + Returns + ======= + + word: str + The longest word in the trie. + """ + def _longest(node: TrieNode, current_word: str, longest_word: str) -> str: + if node.is_terminal and len(current_word) > len(longest_word): + longest_word = current_word + for child in node._children: + longest_word = _longest(node.get_child(child), current_word + child, longest_word) + return longest_word + + return _longest(self.root, "", "")