Skip to content

Commit 7e20098

Browse files
authored
Relaxed rules around subgraph operations; prevented in-place ops (#925)
* Relaxed rules around subgraph operations; prevented in-place ops * Fixed to_subgraph methods
1 parent 166a131 commit 7e20098

File tree

3 files changed

+114
-10
lines changed

3 files changed

+114
-10
lines changed

py2neo/cypher/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def to_subgraph(self):
307307
if s is None:
308308
s = s_
309309
else:
310-
s |= s_
310+
s = s | s_
311311
return s
312312

313313
def to_ndarray(self, dtype=None, order='K'):
@@ -564,7 +564,7 @@ def to_subgraph(self):
564564
if s is None:
565565
s = value
566566
else:
567-
s |= value
567+
s = s | value
568568
return s
569569

570570

py2neo/data.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def __init__(self, nodes=None, relationships=None):
109109
self.__nodes = frozenset(nodes or [])
110110
self.__relationships = frozenset(relationships or [])
111111
self.__nodes |= frozenset(chain.from_iterable(r.nodes for r in self.__relationships))
112-
if not self.__nodes:
113-
raise ValueError("Subgraphs must contain at least one node")
112+
#if not self.__nodes:
113+
# raise ValueError("Subgraphs must contain at least one node")
114114

115115
def __repr__(self):
116116
return "Subgraph({%s}, {%s})" % (", ".join(map(repr, self.nodes)),
@@ -583,6 +583,18 @@ def __or__(self, other):
583583
# use the Walkable implementation.
584584
return Walkable.__or__(self, other)
585585

586+
def __ior__(self, other):
587+
raise TypeError("In-place union is not permitted for %s objects" % self.__class__.__name__)
588+
589+
def __iand__(self, other):
590+
raise TypeError("In-place intersection is not permitted for %s objects" % self.__class__.__name__)
591+
592+
def __isub__(self, other):
593+
raise TypeError("In-place difference is not permitted for %s objects" % self.__class__.__name__)
594+
595+
def __ixor__(self, other):
596+
raise TypeError("In-place symmetric difference is not permitted for %s objects" % self.__class__.__name__)
597+
586598
@property
587599
def graph(self):
588600
return self._graph

test/unit/test_data.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from io import StringIO
2020
from unittest import TestCase
2121

22+
from _pytest.python_api import raises
23+
2224
from py2neo.cypher import Record
2325
from py2neo.data import Subgraph, Walkable, Node, Relationship, Path, walk
2426
from py2neo.integration import Table
@@ -405,8 +407,9 @@ def test_property_keys(self):
405407
assert self.subgraph.keys() == {"name", "age", "since"}
406408

407409
def test_empty_subgraph(self):
408-
with self.assertRaises(ValueError):
409-
Subgraph()
410+
s = Subgraph()
411+
assert len(s.nodes) == 0
412+
assert len(s.relationships) == 0
410413

411414

412415
class WalkableTestCase(TestCase):
@@ -869,40 +872,119 @@ def test_can_concatenate_node_and_none(self):
869872

870873
class UnionTestCase(TestCase):
871874

872-
def test_graph_union(self):
875+
def test_node_union(self):
876+
s = alice | bob
877+
assert len(s.nodes) == 2
878+
assert len(s.relationships) == 0
879+
880+
def test_node_union_in_place(self):
881+
n = Node()
882+
with raises(TypeError):
883+
n |= alice
884+
885+
def test_subgraph_union(self):
873886
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
874887
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
875888
graph = graph_1 | graph_2
876889
assert len(graph.nodes) == 4
877890
assert len(graph.relationships) == 5
878891
assert graph.nodes == (alice | bob | carol | dave).nodes
879892

893+
def test_subgraph_union_in_place(self):
894+
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
895+
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
896+
graph |= graph_2
897+
assert len(graph.nodes) == 4
898+
assert len(graph.relationships) == 5
899+
assert graph.nodes == (alice | bob | carol | dave).nodes
900+
880901

881902
class IntersectionTestCase(TestCase):
882903

883-
def test_graph_intersection(self):
904+
def test_node_intersection_same(self):
905+
s = alice & alice
906+
assert len(s.nodes) == 1
907+
assert len(s.relationships) == 0
908+
909+
def test_node_intersection_different(self):
910+
s = alice & bob
911+
assert len(s.nodes) == 0
912+
assert len(s.relationships) == 0
913+
914+
def test_node_intersection_in_place(self):
915+
n = Node()
916+
with raises(TypeError):
917+
n &= alice
918+
919+
def test_subgraph_intersection(self):
884920
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
885921
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
886922
graph = graph_1 & graph_2
887923
assert len(graph.nodes) == 2
888924
assert len(graph.relationships) == 1
889925
assert graph.nodes == (bob | carol).nodes
890926

927+
def test_subgraph_intersection_in_place(self):
928+
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
929+
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
930+
graph &= graph_2
931+
assert len(graph.nodes) == 2
932+
assert len(graph.relationships) == 1
933+
assert graph.nodes == (bob | carol).nodes
934+
891935

892936
class DifferenceTestCase(TestCase):
893937

894-
def test_graph_difference(self):
938+
def test_node_difference_same(self):
939+
s = alice - alice
940+
assert len(s.nodes) == 0
941+
assert len(s.relationships) == 0
942+
943+
def test_node_difference_different(self):
944+
s = alice - bob
945+
assert len(s.nodes) == 1
946+
assert len(s.relationships) == 0
947+
948+
def test_node_difference_in_place(self):
949+
n = Node()
950+
with raises(TypeError):
951+
n -= alice
952+
953+
def test_subgraph_difference(self):
895954
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
896955
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
897956
graph = graph_1 - graph_2
898957
assert len(graph.nodes) == 3
899958
assert len(graph.relationships) == 2
900959
assert graph.nodes == (alice | bob | carol).nodes
901960

961+
def test_subgraph_difference_in_place(self):
962+
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
963+
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
964+
graph -= graph_2
965+
assert len(graph.nodes) == 3
966+
assert len(graph.relationships) == 2
967+
assert graph.nodes == (alice | bob | carol).nodes
968+
902969

903970
class SymmetricDifferenceTestCase(TestCase):
904971

905-
def test_graph_symmetric_difference(self):
972+
def test_node_symmetric_difference_same(self):
973+
s = alice ^ alice
974+
assert len(s.nodes) == 0
975+
assert len(s.relationships) == 0
976+
977+
def test_node_symmetric_difference_different(self):
978+
s = alice ^ bob
979+
assert len(s.nodes) == 2
980+
assert len(s.relationships) == 0
981+
982+
def test_node_symmetric_difference_in_place(self):
983+
n = Node()
984+
with raises(TypeError):
985+
n ^= alice
986+
987+
def test_subgraph_symmetric_difference(self):
906988
graph_1 = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
907989
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
908990
graph = graph_1 ^ graph_2
@@ -912,6 +994,16 @@ def test_graph_symmetric_difference(self):
912994
assert graph.relationships == frozenset(alice_knows_bob | alice_likes_carol |
913995
carol_married_to_dave | dave_works_for_dave)
914996

997+
def test_subgraph_symmetric_difference_in_place(self):
998+
graph = (alice_knows_bob | alice_likes_carol | carol_dislikes_bob)
999+
graph_2 = (carol_dislikes_bob | carol_married_to_dave | dave_works_for_dave)
1000+
graph ^= graph_2
1001+
assert len(graph.nodes) == 4
1002+
assert len(graph.relationships) == 4
1003+
assert graph.nodes == (alice | bob | carol | dave).nodes
1004+
assert graph.relationships == frozenset(alice_knows_bob | alice_likes_carol |
1005+
carol_married_to_dave | dave_works_for_dave)
1006+
9151007

9161008
def test_record_repr():
9171009
person = Record(["name", "age"], ["Alice", 33])

0 commit comments

Comments
 (0)