Skip to content

Commit 6a55af4

Browse files
committed
Changes so that the run method is not required anymore
1 parent 8d9c97b commit 6a55af4

File tree

4 files changed

+23
-20
lines changed

4 files changed

+23
-20
lines changed

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import abc
1817
import asyncio
1918
import enum
2019
import json
@@ -115,7 +114,7 @@ def fix_invalid_json(raw_json: str) -> str:
115114
return repaired_json
116115

117116

118-
class EntityRelationExtractor(Component, abc.ABC):
117+
class EntityRelationExtractor(Component):
119118
"""Abstract class for entity relation extraction components.
120119
121120
Args:
@@ -133,15 +132,14 @@ def __init__(
133132
self.on_error = on_error
134133
self.create_lexical_graph = create_lexical_graph
135134

136-
@abc.abstractmethod
137135
async def run(
138136
self,
139137
chunks: TextChunks,
140138
document_info: Optional[DocumentInfo] = None,
141139
lexical_graph_config: Optional[LexicalGraphConfig] = None,
142140
**kwargs: Any,
143141
) -> Neo4jGraph:
144-
pass
142+
raise NotImplementedError()
145143

146144
def update_ids(
147145
self,

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import abc
1615
from typing import Any, Optional
1716

1817
import neo4j
@@ -22,7 +21,7 @@
2221
from neo4j_graphrag.utils import driver_config
2322

2423

25-
class EntityResolver(Component, abc.ABC):
24+
class EntityResolver(Component):
2625
"""Entity resolution base class
2726
2827
Args:
@@ -38,9 +37,8 @@ def __init__(
3837
self.driver = driver_config.override_user_agent(driver)
3938
self.filter_query = filter_query
4039

41-
@abc.abstractmethod
4240
async def run(self, *args: Any, **kwargs: Any) -> ResolutionStats:
43-
pass
41+
raise NotImplementedError()
4442

4543

4644
class SinglePropertyExactMatchResolver(EntityResolver):

src/neo4j_graphrag/experimental/pipeline/component.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import abc
1817
import inspect
1918
from typing import Any, get_type_hints
2019

@@ -31,18 +30,14 @@ class DataModel(BaseModel):
3130
pass
3231

3332

34-
class ComponentMeta(abc.ABCMeta):
33+
class ComponentMeta(type):
3534
def __new__(
3635
meta, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
3736
) -> type:
3837
# extract required inputs and outputs from the run method signature
3938
run_method = attrs.get("run")
4039
run_context_method = attrs.get("run_with_context")
4140
run = run_context_method or run_method
42-
if run is None:
43-
raise RuntimeError(
44-
f"Either 'run' or 'run_with_context' must be implemented in component: '{name}'"
45-
)
4641
sig = inspect.signature(run)
4742
attrs["component_inputs"] = {
4843
param.name: {
@@ -73,7 +68,7 @@ def __new__(
7368
return type.__new__(meta, name, bases, attrs)
7469

7570

76-
class Component(abc.ABC, metaclass=ComponentMeta):
71+
class Component(metaclass=ComponentMeta):
7772
"""Interface that needs to be implemented
7873
by all components.
7974
"""
@@ -84,12 +79,27 @@ class Component(abc.ABC, metaclass=ComponentMeta):
8479
component_inputs: dict[str, dict[str, str | bool]]
8580
component_outputs: dict[str, dict[str, str | bool | type]]
8681

87-
@abc.abstractmethod
8882
async def run(self, *args: Any, **kwargs: Any) -> DataModel:
89-
pass
83+
"""This function is planned for deprecation in a future release.
84+
85+
Note: if `run_with_context` is implemented, this method will not be used.
86+
"""
87+
raise NotImplementedError(
88+
"You must implement the `run` or `run_with_context` method. "
89+
"`run` method will be marked for deprecation in a future release."
90+
)
9091

9192
async def run_with_context(
9293
self, context_: RunContext, *args: Any, **kwargs: Any
9394
) -> DataModel:
95+
"""This method is called by the pipeline orchestrator.
96+
The `context_` parameter contains information about
97+
the pipeline run: the `run_id` and a `notify` function
98+
that can be used to send events from the component to
99+
the pipeline callback.
100+
101+
For now, it defaults to calling the `run` method, but it
102+
is meant to replace the `run` method in a future release.
103+
"""
94104
# default behavior to prevent a breaking change
95105
return await self.run(*args, **kwargs)

tests/unit/experimental/pipeline/components.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ async def run(self, number1: int, number2: int = 2) -> IntResultModel:
4545

4646

4747
class ComponentMultiplyWithContext(Component):
48-
async def run(self, number1: int, number2: int) -> IntResultModel:
49-
return IntResultModel(result=number1 * number2)
50-
5148
async def run_with_context(
5249
self, context_: RunContext, number1: int, number2: int = 2
5350
) -> IntResultModel:

0 commit comments

Comments
 (0)