Skip to content

Commit 6851a0b

Browse files
Adds a fixed size text splitter component (#139)
* Added fixed size text splitter class * Updated docs * Updated examples * Added init defaults to fixed size text splitter * Fixed bug in example * Updated E2E tests * Update docs/source/api.rst Co-authored-by: willtai <wtaisen@gmail.com> * Updated fixed size splitter defaults --------- Co-authored-by: willtai <wtaisen@gmail.com>
1 parent fc7d319 commit 6851a0b

File tree

9 files changed

+173
-24
lines changed

9 files changed

+173
-24
lines changed

docs/source/api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ TextSplitter
2727
.. autoclass:: neo4j_graphrag.experimental.components.text_splitters.base.TextSplitter
2828
:members: run
2929

30+
FixedSizeSplitter
31+
=================
32+
33+
.. autoclass:: neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter.FixedSizeSplitter
34+
:members: run
35+
3036
LangChainTextSplitterAdapter
3137
============================
3238

docs/source/user_guide_kg_builder.rst

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,24 @@ Document Splitter
9696
=================
9797

9898
Document splitters, as the name indicates, split documents into smaller chunks
99-
that can be processed within the LLM token limits. Wrappers for LangChain and LlamaIndex
100-
text splitters are included in this package:
99+
that can be processed within the LLM token limits:
101100

101+
.. code:: python
102+
103+
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
104+
105+
splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
106+
splitter.run(text="Hello World. Life is beautiful.")
107+
108+
109+
Wrappers for LangChain and LlamaIndex text splitters are included in this package:
102110

103111
.. code:: python
104112
105113
from langchain_text_splitters import CharacterTextSplitter
106114
from neo4j_graphrag.experimental.components.text_splitters.langchain import LangChainTextSplitterAdapter
107115
splitter = LangChainTextSplitterAdapter(
108-
CharacterTextSplitter(chunk_size=500, chunk_overlap=100, separator=".")
116+
CharacterTextSplitter(chunk_size=4000, chunk_overlap=200, separator=".")
109117
)
110118
splitter.run(text="Hello World. Life is beautiful.")
111119

examples/pipeline/kg_builder_from_pdf.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Any, Dict, List
2020

2121
import neo4j
22-
from langchain_text_splitters import CharacterTextSplitter
2322
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
2423
LLMEntityRelationExtractor,
2524
OnError,
@@ -31,8 +30,8 @@
3130
SchemaEntity,
3231
SchemaRelation,
3332
)
34-
from neo4j_graphrag.experimental.components.text_splitters.langchain import (
35-
LangChainTextSplitterAdapter,
33+
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
34+
FixedSizeSplitter,
3635
)
3736
from neo4j_graphrag.experimental.pipeline import Component, DataModel
3837
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
@@ -142,8 +141,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
142141
pipe = Pipeline()
143142
pipe.add_component(PdfLoader(), "pdf_loader")
144143
pipe.add_component(
145-
LangChainTextSplitterAdapter(CharacterTextSplitter(separator=". \n")),
146-
"splitter",
144+
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter"
147145
)
148146
pipe.add_component(SchemaBuilder(), "schema")
149147
pipe.add_component(

examples/pipeline/kg_builder_from_text.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import logging.config
1919

2020
import neo4j
21-
from langchain_text_splitters import CharacterTextSplitter
2221
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
2322
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
2423
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -32,8 +31,8 @@
3231
SchemaProperty,
3332
SchemaRelation,
3433
)
35-
from neo4j_graphrag.experimental.components.text_splitters.langchain import (
36-
LangChainTextSplitterAdapter,
34+
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
35+
FixedSizeSplitter,
3736
)
3837
from neo4j_graphrag.experimental.pipeline import Pipeline
3938
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
@@ -63,7 +62,7 @@
6362
async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
6463
"""This is where we define and run the KG builder pipeline, instantiating a few
6564
components:
66-
- Text Splitter: in this example we use a text splitter from the LangChain package
65+
- Text Splitter: in this example we use the fixed size text splitter
6766
- Schema Builder: this component takes a list of entities, relationships and
6867
possible triplets as inputs, validates them and returns a schema ready to use
6968
for the rest of the pipeline
@@ -76,10 +75,8 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
7675
pipe = Pipeline()
7776
# define the components
7877
pipe.add_component(
79-
LangChainTextSplitterAdapter(
80-
# chunk_size=50 for the sake of this demo
81-
CharacterTextSplitter(chunk_size=50, chunk_overlap=10, separator=".")
82-
),
78+
# chunk_size=4000 and chunk_overlap=200 match the FixedSizeSplitter defaults
79+
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200),
8380
"splitter",
8481
)
8582
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")

src/neo4j_graphrag/experimental/components/text_splitters/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,12 @@ class TextSplitter(Component):
2525

2626
@abstractmethod
2727
async def run(self, text: str) -> TextChunks:
28+
"""Splits a piece of text into chunks.
29+
30+
Args:
31+
text (str): The text to be split.
32+
33+
Returns:
34+
TextChunks: A list of chunks.
35+
"""
2836
pass
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from pydantic import validate_call
16+
17+
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
18+
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
19+
20+
21+
class FixedSizeSplitter(TextSplitter):
    """Text splitter which splits the input text into fixed size chunks with optional overlap.

    Args:
        chunk_size (int): The number of characters in each chunk. Must be strictly positive.
        chunk_overlap (int): The number of characters from the previous chunk to overlap with each chunk. Must be non-negative and less than `chunk_size`.

    Raises:
        ValueError: If `chunk_size` is not strictly positive, if `chunk_overlap`
            is negative, or if `chunk_overlap` is not strictly less than `chunk_size`.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
        from neo4j_graphrag.experimental.pipeline import Pipeline

        pipeline = Pipeline()
        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
        pipeline.add_component(text_splitter, "text_splitter")
    """

    @validate_call
    def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None:
        # A non-positive size or a negative overlap would make the window
        # arithmetic in `run` skip characters or slice from the wrong end,
        # so reject them up front instead of producing corrupt chunks.
        if chunk_size <= 0:
            raise ValueError("chunk_size must be strictly positive")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must not be negative")
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be strictly less than chunk_size")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @validate_call
    async def run(self, text: str) -> TextChunks:
        """Splits a piece of text into chunks.

        Each chunk starts `chunk_size - chunk_overlap` characters after the
        previous one and holds at most `chunk_size` characters. Splitting stops
        as soon as a chunk reaches the end of the text, so no trailing chunk
        consisting solely of already-emitted overlap characters is produced.

        Args:
            text (str): The text to be split.

        Returns:
            TextChunks: A list of chunks. Empty when `text` is empty.
        """
        text_length = len(text)
        # Distance between the start offsets of two consecutive chunks.
        step = self.chunk_size - self.chunk_overlap
        chunks = []
        for index, start in enumerate(range(0, text_length, step)):
            end = min(start + self.chunk_size, text_length)
            chunks.append(TextChunk(text=text[start:end], index=index))
            if end == text_length:
                # Any further window would only repeat the tail of this chunk.
                break
        return TextChunks(chunks=chunks)

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class TextChunk(BaseModel):
2626
2727
Attributes:
2828
text (str): The raw chunk text.
29+
index (int): The position of this chunk in the original document.
2930
metadata (Optional[dict[str, Any]]): Metadata associated with this chunk such as the id of the next chunk in the original document.
3031
"""
3132

tests/e2e/test_kg_builder_pipeline_e2e.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import neo4j
2222
import pytest
23-
from langchain_text_splitters import CharacterTextSplitter
2423
from neo4j_graphrag.embeddings.base import Embedder
2524
from neo4j_graphrag.exceptions import LLMGenerationError
2625
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
@@ -35,8 +34,8 @@
3534
SchemaProperty,
3635
SchemaRelation,
3736
)
38-
from neo4j_graphrag.experimental.components.text_splitters.langchain import (
39-
LangChainTextSplitterAdapter,
37+
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
38+
FixedSizeSplitter,
4039
)
4140
from neo4j_graphrag.experimental.pipeline import Pipeline
4241
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
@@ -63,10 +62,8 @@ def schema_builder() -> SchemaBuilder:
6362

6463

6564
@pytest.fixture
66-
def text_splitter() -> LangChainTextSplitterAdapter:
67-
return LangChainTextSplitterAdapter(
68-
CharacterTextSplitter(chunk_size=50, chunk_overlap=10, separator="\n\n")
69-
)
65+
def text_splitter() -> FixedSizeSplitter:
66+
return FixedSizeSplitter(chunk_size=500, chunk_overlap=100)
7067

7168

7269
@pytest.fixture
@@ -89,7 +86,7 @@ def kg_writer(driver: neo4j.Driver) -> Neo4jWriter:
8986

9087
@pytest.fixture
9188
def kg_builder_pipeline(
92-
text_splitter: LangChainTextSplitterAdapter,
89+
text_splitter: FixedSizeSplitter,
9390
chunk_embedder: TextChunkEmbedder,
9491
schema_builder: SchemaBuilder,
9592
entity_relation_extractor: LLMEntityRelationExtractor,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import pytest
16+
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
17+
FixedSizeSplitter,
18+
)
19+
from neo4j_graphrag.experimental.components.types import TextChunk
20+
21+
22+
@pytest.mark.asyncio
async def test_split_text_no_overlap() -> None:
    """Without overlap the text is cut into consecutive 5-character chunks."""
    splitter = FixedSizeSplitter(5, 0)
    result = await splitter.run("may thy knife chip and shatter")
    expected_texts = ["may t", "hy kn", "ife c", "hip a", "nd sh", "atter"]
    expected = [
        TextChunk(text=piece, index=position)
        for position, piece in enumerate(expected_texts)
    ]
    assert result.chunks == expected
38+
39+
40+
@pytest.mark.asyncio
async def test_split_text_with_overlap() -> None:
    """With a 2-character overlap each chunk repeats the tail of the previous one."""
    splitter = FixedSizeSplitter(10, 2)
    result = await splitter.run("may thy knife chip and shatter")
    expected_texts = ["may thy kn", "knife chip", "ip and sha", "hatter"]
    expected = [
        TextChunk(text=piece, index=position)
        for position, piece in enumerate(expected_texts)
    ]
    assert result.chunks == expected
54+
55+
56+
@pytest.mark.asyncio
async def test_split_text_empty_string() -> None:
    """Splitting an empty string yields no chunks at all."""
    splitter = FixedSizeSplitter(5, 1)
    result = await splitter.run("")
    assert result.chunks == []
64+
65+
66+
def test_invalid_chunk_overlap() -> None:
    """An overlap equal to the chunk size must be rejected at construction time."""
    with pytest.raises(ValueError) as excinfo:
        FixedSizeSplitter(5, 5)
    # Check the message of the raised exception itself rather than the string
    # form of pytest's ExceptionInfo wrapper.
    assert "chunk_overlap must be strictly less than chunk_size" in str(excinfo.value)

0 commit comments

Comments
 (0)