Skip to content

Commit 000e27b

Browse files
authored
[Feat][Spark] Align info implementation of spark with c++ (#316)
1 parent 2faccd8 commit 000e27b

File tree

18 files changed

+177
-339
lines changed

18 files changed

+177
-339
lines changed

pyspark/graphar_pyspark/info.py

Lines changed: 36 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,6 @@ def __init__(
539539
aligned_by: Optional[str],
540540
prefix: Optional[str],
541541
file_type: Optional[FileType],
542-
property_groups: Optional[Sequence[PropertyGroup]],
543542
jvm_obj: Optional[JavaObject],
544543
) -> None:
545544
"""One should not use this constructor directly, please use `from_scala` or `from_python`."""
@@ -552,9 +551,6 @@ def __init__(
552551
jvm_adj_list.setAligned_by(aligned_by)
553552
jvm_adj_list.setPrefix(prefix)
554553
jvm_adj_list.setFile_type(file_type.value)
555-
jvm_adj_list.setProperty_groups(
556-
[py_property_group.to_scala() for py_property_group in property_groups],
557-
)
558554
self._jvm_adj_list_obj = jvm_adj_list
559555

560556
def get_ordered(self) -> bool:
@@ -615,25 +611,6 @@ def set_file_type(self, file_type: FileType) -> None:
615611
"""
616612
self._jvm_adj_list_obj.setFile_type(file_type.value)
617613

618-
def get_property_groups(self) -> Sequence[PropertyGroup]:
619-
"""Get property groups from the corresponding JVM object.
620-
621-
:returns: property groups
622-
"""
623-
return [
624-
PropertyGroup.from_scala(jvm_property_group)
625-
for jvm_property_group in self._jvm_adj_list_obj.getProperty_groups()
626-
]
627-
628-
def set_property_groups(self, property_groups: Sequence[PropertyGroup]) -> None:
629-
"""Mutate the corresponding JVM object.
630-
631-
:param property_groups: new property groups
632-
"""
633-
self._jvm_adj_list_obj.setProperty_groups(
634-
[p_group.to_scala() for p_group in property_groups],
635-
)
636-
637614
def get_adj_list_type(self) -> AdjListType:
638615
"""Get adj list type.
639616
@@ -658,7 +635,7 @@ def from_scala(
658635
:param jvm_obj: scala object in JVM.
659636
:returns: instance of Python Class.
660637
"""
661-
return AdjList(None, None, None, None, None, jvm_obj)
638+
return AdjList(None, None, None, None, jvm_obj)
662639

663640
@classmethod
664641
def from_python(
@@ -667,19 +644,17 @@ def from_python(
667644
aligned_by: str,
668645
prefix: str,
669646
file_type: FileType,
670-
property_groups: Sequence[PropertyGroup],
671647
) -> AdjListClassType:
672648
"""Create an instance of the class from python arguments.
673649
674650
:param ordered: ordered flag
675651
:param aligned_by: recommended values are "src" or "dst"
676652
:param prefix: path prefix
677653
:param file_type: file type
678-
:param property_groups: sequence of PropertyGroup objects
679654
"""
680655
if not prefix.endswith(os.sep):
681656
prefix += os.sep
682-
return AdjList(ordered, aligned_by, prefix, file_type, property_groups, None)
657+
return AdjList(ordered, aligned_by, prefix, file_type, None)
683658

684659
def __eq__(self, other: object) -> bool:
685660
if not isinstance(other, AdjList):
@@ -690,14 +665,6 @@ def __eq__(self, other: object) -> bool:
690665
and (self.get_aligned_by() == other.get_aligned_by())
691666
and (self.get_prefix() == other.get_prefix())
692667
and (self.get_file_type() == other.get_file_type())
693-
and (len(self.get_property_groups()) == len(other.get_property_groups()))
694-
and all(
695-
left_pg == right_pg
696-
for left_pg, right_pg in zip(
697-
self.get_property_groups(),
698-
other.get_property_groups(),
699-
)
700-
)
701668
)
702669

703670

@@ -719,6 +686,7 @@ def __init__(
719686
directed: Optional[bool],
720687
prefix: Optional[str],
721688
adj_lists: Sequence[AdjList],
689+
property_groups: Optional[Sequence[PropertyGroup]],
722690
version: Optional[str],
723691
jvm_edge_info_obj: JavaObject,
724692
) -> None:
@@ -739,6 +707,9 @@ def __init__(
739707
edge_info.setAdj_lists(
740708
[py_adj_list.to_scala() for py_adj_list in adj_lists],
741709
)
710+
edge_info.setProperty_groups(
711+
[py_property_group.to_scala() for py_property_group in property_groups],
712+
)
742713
edge_info.setVersion(version)
743714
self._jvm_edge_info_obj = edge_info
744715

@@ -873,6 +844,27 @@ def set_adj_lists(self, adj_lists: Sequence[AdjList]) -> None:
873844
[py_adj_list.to_scala() for py_adj_list in adj_lists],
874845
)
875846

847+
def get_property_groups(self) -> Sequence[PropertyGroup]:
848+
"""Get the property groups of the edge info.
849+
850+
WARNING! Exceptions from the JVM are not checked inside, it is just a proxy-method!
851+
852+
:returns: property groups of edge info.
853+
"""
854+
return [
855+
PropertyGroup.from_scala(jvm_property_group)
856+
for jvm_property_group in self._jvm_edge_info_obj.getProperty_groups()
857+
]
858+
859+
def set_property_groups(self, property_groups: Sequence[PropertyGroup]) -> None:
860+
"""Mutate the corresponding JVM object.
861+
862+
:param property_groups: the new property groups, sequence of PropertyGroup
863+
"""
864+
self._jvm_edge_info_obj.setProperty_groups(
865+
[py_property_group.to_scala() for py_property_group in property_groups],
866+
)
867+
876868
def get_version(self) -> str:
877869
"""Get GAR version from the corresponding JVM object.
878870
@@ -912,6 +904,7 @@ def from_scala(cls: type[EdgeInfoType], jvm_obj: JavaObject) -> EdgeInfoType:
912904
None,
913905
None,
914906
None,
907+
None,
915908
jvm_obj,
916909
)
917910

@@ -927,6 +920,7 @@ def from_python(
927920
directed: bool,
928921
prefix: str,
929922
adj_lists: Sequence[AdjList],
923+
property_groups: Sequence[PropertyGroup],
930924
version: str,
931925
) -> EdgeInfoType:
932926
"""Create an instance of the class from python arguments.
@@ -940,6 +934,7 @@ def from_python(
940934
:param directed: directed graph flag
941935
:param prefix: path prefix
942936
:param adj_lists: sequence of AdjList objects
937+
:param property_groups: sequence of PropertyGroup objects
943938
:param version: version of GAR format
944939
"""
945940
if not prefix.endswith(os.sep):
@@ -955,6 +950,7 @@ def from_python(
955950
directed,
956951
prefix,
957952
adj_lists,
953+
property_groups,
958954
version,
959955
None,
960956
)
@@ -990,41 +986,18 @@ def get_adj_list_file_type(self, adj_list_type: AdjListType) -> FileType:
990986
self._jvm_edge_info_obj.getAdjListFileType(adj_list_type.to_scala()),
991987
)
992988

993-
def get_property_groups(
994-
self,
995-
adj_list_type: AdjListType,
996-
) -> Sequence[PropertyGroup]:
997-
"""Get the property groups of adj list type.
998-
999-
WARNING! Exceptions from the JVM are not checked inside, it is just a proxy-method!
1000-
1001-
:param adj_list_type: the input adj list type.
1002-
:returns: property group of the input adj list type, if edge info not support the adj list type,
1003-
raise an IllegalArgumentException error.
1004-
"""
1005-
return [
1006-
PropertyGroup.from_scala(property_group)
1007-
for property_group in self._jvm_edge_info_obj.getPropertyGroups(
1008-
adj_list_type.to_scala(),
1009-
)
1010-
]
1011-
1012989
def contain_property_group(
1013990
self,
1014991
property_group: PropertyGroup,
1015-
adj_list_type: AdjListType,
1016992
) -> bool:
1017-
"""Check if the edge info contains the property group in certain adj list structure.
993+
"""Check if the edge info contains the property group.
1018994
1019995
:param property_group: the property group to check.
1020-
:param adj_list_type: the type of adj list structure.
1021996
:returns: true if the edge info contains the property group in certain adj list
1022-
structure. If edge info not support the given adj list type or not
1023-
contains the property group in the adj list structure, return false.
997+
structure.
1024998
"""
1025999
return self._jvm_edge_info_obj.containPropertyGroup(
10261000
property_group.to_scala(),
1027-
adj_list_type.to_scala(),
10281001
)
10291002

10301003
def contain_property(self, property_name: str) -> bool:
@@ -1038,23 +1011,17 @@ def contain_property(self, property_name: str) -> bool:
10381011
def get_property_group(
10391012
self,
10401013
property_name: str,
1041-
adj_list_type: AdjListType,
10421014
) -> PropertyGroup:
10431015
"""Get the property group that contains the given property.
10441016
10451017
WARNING! Exceptions from the JVM are not checked inside, it is just a proxy-method!
10461018
10471019
:param property_name: name of the property.
1048-
:param adj_list_type: the type of adj list structure.
1049-
:returns: property group that contains the property. If edge info not support the
1050-
adj list type, or not find the property group that contains the property,
1051-
return false.
1020+
:returns: property group that contains the property. If the edge info has no property group that contains the property,
1021+
an error is raised.
10521022
"""
10531023
return PropertyGroup.from_scala(
1054-
self._jvm_edge_info_obj.getPropertyGroup(
1055-
property_name,
1056-
adj_list_type.to_scala(),
1057-
),
1024+
self._jvm_edge_info_obj.getPropertyGroup(property_name),
10581025
)
10591026

10601027
def get_property_type(self, property_name: str) -> GarType:

pyspark/tests/test_info.py

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -87,26 +87,11 @@ def test_property_group(spark):
8787
def test_adj_list(spark):
8888
initialize(spark)
8989

90-
props_list_1 = [
91-
Property.from_python("non_primary", GarType.DOUBLE, False),
92-
Property.from_python("primary", GarType.INT64, True),
93-
]
94-
95-
props_list_2 = [
96-
Property.from_python("non_primary", GarType.DOUBLE, False),
97-
Property.from_python("primary", GarType.INT64, True),
98-
Property("another_one", GarType.LIST, False),
99-
]
100-
10190
adj_list_from_py = AdjList.from_python(
10291
True,
10392
"dest",
10493
"prefix",
10594
FileType.PARQUET,
106-
[
107-
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1),
108-
PropertyGroup.from_python("prefix2", FileType.ORC, props_list_2),
109-
],
11095
)
11196

11297
assert adj_list_from_py == AdjList.from_scala(adj_list_from_py.to_scala())
@@ -125,28 +110,6 @@ def test_adj_list(spark):
125110
adj_list_from_py.set_file_type(FileType.CSV)
126111
assert adj_list_from_py.get_file_type() == FileType.CSV
127112

128-
adj_list_from_py.set_property_groups(
129-
adj_list_from_py.get_property_groups()
130-
+ [
131-
PropertyGroup.from_python(
132-
"prefix3", FileType.CSV, props_list_1 + props_list_2
133-
)
134-
]
135-
)
136-
assert all(
137-
pg_left == pg_right
138-
for pg_left, pg_right in zip(
139-
adj_list_from_py.get_property_groups(),
140-
[
141-
PropertyGroup.from_python("prefix1", FileType.PARQUET, props_list_1),
142-
PropertyGroup.from_python("prefix2", FileType.ORC, props_list_2),
143-
PropertyGroup.from_python(
144-
"prefix3", FileType.CSV, props_list_1 + props_list_2
145-
),
146-
],
147-
)
148-
)
149-
150113

151114
def test_vertex_info(spark):
152115
initialize(spark)
@@ -293,6 +256,7 @@ def test_edge_info(spark):
293256
directed=True,
294257
prefix="prefix",
295258
adj_lists=[],
259+
property_groups=[],
296260
version="v1",
297261
)
298262

@@ -335,15 +299,19 @@ def test_edge_info(spark):
335299
"dest",
336300
"prefix",
337301
FileType.PARQUET,
338-
[
339-
PropertyGroup.from_python(
340-
"prefix1", FileType.PARQUET, props_list_1
341-
),
342-
],
343302
)
344303
]
345304
)
305+
py_edge_info.set_property_groups(
306+
[
307+
PropertyGroup.from_python(
308+
"prefix1", FileType.PARQUET, props_list_1
309+
),
310+
],
311+
)
312+
346313
assert len(py_edge_info.get_adj_lists()) == 1
314+
assert len(py_edge_info.get_property_groups()) == 1
347315

348316
# Load from YAML
349317
person_knows_person_info = EdgeInfo.load_edge_info(
@@ -387,12 +355,11 @@ def test_edge_info(spark):
387355
!= 0
388356
)
389357
assert (
390-
len(person_knows_person_info.get_property_groups(AdjListType.ORDERED_BY_SOURCE))
358+
len(person_knows_person_info.get_property_groups())
391359
== 1
392360
)
393361
assert person_knows_person_info.contain_property_group(
394-
person_knows_person_info.get_property_groups(AdjListType.UNORDERED_BY_DEST)[0],
395-
AdjListType.UNORDERED_BY_DEST,
362+
person_knows_person_info.get_property_groups()[0],
396363
)
397364
assert person_knows_person_info.get_property_type("creationDate") == GarType.STRING
398365
assert person_knows_person_info.is_primary_key("creationDate") == False
@@ -443,7 +410,7 @@ def test_edge_info(spark):
443410
assert (
444411
person_knows_person_info.get_property_file_path(
445412
person_knows_person_info.get_property_group(
446-
"creationDate", AdjListType.ORDERED_BY_SOURCE
413+
"creationDate",
447414
),
448415
AdjListType.ORDERED_BY_SOURCE,
449416
0,
@@ -454,7 +421,7 @@ def test_edge_info(spark):
454421
assert (
455422
person_knows_person_info.get_property_group_path_prefix(
456423
person_knows_person_info.get_property_group(
457-
"creationDate", AdjListType.ORDERED_BY_SOURCE
424+
"creationDate",
458425
),
459426
AdjListType.ORDERED_BY_SOURCE,
460427
0,
@@ -463,7 +430,7 @@ def test_edge_info(spark):
463430
assert (
464431
person_knows_person_info.get_property_group_path_prefix(
465432
person_knows_person_info.get_property_group(
466-
"creationDate", AdjListType.ORDERED_BY_SOURCE
433+
"creationDate",
467434
),
468435
AdjListType.ORDERED_BY_SOURCE,
469436
None,
@@ -536,6 +503,7 @@ def test_graph_info(spark):
536503
True,
537504
"prefix",
538505
[],
506+
[],
539507
"v1",
540508
)
541509
)

pyspark/tests/test_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ def test_edge_reader(spark):
8383
assert (
8484
"_graphArEdgeIndex"
8585
in edge_reader.read_edge_property_group(
86-
edge_info.get_property_group("weight", AdjListType.ORDERED_BY_SOURCE)
86+
edge_info.get_property_group("weight")
8787
).columns
8888
)
8989
assert (
9090
edge_reader.read_edge_property_group(
91-
edge_info.get_property_group("weight", AdjListType.ORDERED_BY_SOURCE)
91+
edge_info.get_property_group("weight")
9292
).count()
9393
> 0
9494
)

0 commit comments

Comments
 (0)