diff --git a/pydatastructs/graphs/adjacency_matrix.py b/pydatastructs/graphs/adjacency_matrix.py index 48e0d3489..a8764de4c 100644 --- a/pydatastructs/graphs/adjacency_matrix.py +++ b/pydatastructs/graphs/adjacency_matrix.py @@ -57,12 +57,47 @@ def neighbors(self, node): return neighbors def add_vertex(self, node): - raise NotImplementedError("Currently we allow " - "adjacency matrix for static graphs only") + if node.name in self.matrix: + raise ValueError("Vertex %s already exists in the graph." % node.name) + self.vertices.append(node.name) + setattr(self, node.name, node) + self.matrix[node.name] = {} def remove_vertex(self, node): - raise NotImplementedError("Currently we allow " - "adjacency matrix for static graphs only.") + node = str(node) + if node not in self.matrix: + raise ValueError("Vertex '%s' is not present in the graph." % node) + + # first we need to remove the edges involving the `node` + + # removing records from dict while iterating over them is tricky + # so we'll first identify which edges to remove first + + edges_to_remove = [] + + for target in self.matrix[node]: + if self.matrix[node].get(target, False): + edges_to_remove.append((node, target)) + + for source in self.vertices: + if self.matrix[source].get(node): + edges_to_remove.append((source, node)) + + # remove the identified edge weights + for source, target in edges_to_remove: + edge_key = str(source) + "_" + str(target) + self.edge_weights.pop(edge_key) + + self.vertices.remove(node) + # eliminate all outgoing edges + self.matrix.pop(node, None) + + # eliminate all incoming edges + for source in self.vertices: + self.matrix[source].pop(node, None) + + if hasattr(self, node): + delattr(self, node) def add_edge(self, source, target, cost=None): source, target = str(source), str(target) diff --git a/pydatastructs/graphs/tests/test_adjacency_matrix.py b/pydatastructs/graphs/tests/test_adjacency_matrix.py index 2dace4260..7d6e1e4fd 100644 --- a/pydatastructs/graphs/tests/test_adjacency_matrix.py +++ b/pydatastructs/graphs/tests/test_adjacency_matrix.py @@ -30,3 +30,23 @@ def test_AdjacencyMatrix(): assert raises(ValueError, lambda: g.add_edge('v', 'x')) assert raises(ValueError, lambda: g.add_edge(2, 3)) assert raises(ValueError, lambda: g.add_edge(3, 2)) + assert g.num_vertices() == 3 + v_3 = AdjacencyMatrixGraphNode(3, 3) + g.add_vertex(v_3) + assert g.num_vertices() == 4 + g.add_edge(3, 1, 0) + g.add_edge(3, 2, 0) + g.add_edge(2, 3, 0) + assert g.is_adjacent(3, 1) is True + assert g.is_adjacent(0, 2) is False + assert g.is_adjacent(1, 3) is False + assert g.is_adjacent(2, 3) is True + assert g.is_adjacent(3, 2) is True + neighbors = g.neighbors(3) + assert neighbors == [v_1, v_2] + neighbors = g.neighbors(2) + assert neighbors == [v_0, v_3] + g.remove_vertex(3) + neighbors = g.neighbors(2) + assert neighbors == [v_0] + assert g.num_vertices() == 3