
Commit 191b1b1

Improve text splitter to avoid cutting words in chunks (#242)
1 parent 68cba61 commit 191b1b1

File tree: 10 files changed, +254 −22 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 ### Changed
 - Updated LLM implementations to handle message history consistently across providers.
 - The `id_prefix` parameter in the `LexicalGraphConfig` is deprecated.
+- Changed the default behaviour of `FixedSizeSplitter` to avoid words cut-off in the chunks whenever it is possible.

 ### Fixed
 - IDs for the Document and Chunk nodes in the lexical graph are now randomly generated and unique across multiple runs, fixing issues in the lexical graph where relationships were created between chunks that were created by different pipeline runs.

docs/source/user_guide_kg_builder.rst

Lines changed: 4 additions & 1 deletion
@@ -581,9 +581,12 @@ that can be processed within the LLM token limits:

     from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter

-    splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
+    splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False)
     splitter.run(text="Hello World. Life is beautiful.")

+.. note::
+
+    `approximate` flag is by default set to True to ensure clean chunk start and end (i.e. avoid words cut in the middle) whenever it is possible.

 Wrappers for LangChain and LlamaIndex text splitters are included in this package:
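
To see the new flag in action end to end, here is a minimal sketch built only from the constructor and `run` signatures shown in this commit; the tiny chunk_size is an illustrative choice to make the boundary effect visible, not a recommended setting:

import asyncio

from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    FixedSizeSplitter,
)


async def compare_splitters() -> None:
    text = "Hello World. Life is beautiful."
    for approximate in (False, True):
        # approximate=True (the new default) nudges chunk boundaries back onto
        # whitespace; approximate=False keeps the old exact fixed-size cuts
        splitter = FixedSizeSplitter(
            chunk_size=10, chunk_overlap=2, approximate=approximate
        )
        result = await splitter.run(text=text)
        print(approximate, [chunk.text for chunk in result.chunks])


if __name__ == "__main__":
    asyncio.run(compare_splitters())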

examples/customize/build_graph/components/splitters/fixed_size_splitter.py

Lines changed: 2 additions & 1 deletion
@@ -6,9 +6,10 @@

 async def main() -> TextChunks:
     splitter = FixedSizeSplitter(
-        # optionally, configure chunk_size and chunk_overlap
+        # optionally, configure chunk_size, chunk_overlap, and approximate flag
         # chunk_size=4000,
         # chunk_overlap=200,
+        # approximate = False
     )
     chunks = await splitter.run(text="text to split")
     return chunks
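
The diff above shows only a fragment of the example file; a self-contained version might look like the following sketch (the `TextChunks` import path is taken from the source module changed later in this commit):

import asyncio

from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import TextChunks


async def main() -> TextChunks:
    # approximate=False opts out of the new word-boundary adjustment and
    # restores exact fixed-size chunks
    splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False)
    return await splitter.run(text="text to split")


if __name__ == "__main__":
    print(asyncio.run(main()))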

examples/customize/build_graph/pipeline/kg_builder_from_pdf.py

Lines changed: 4 additions & 2 deletions
@@ -17,7 +17,6 @@
 import asyncio
 import logging

-import neo4j
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
     LLMEntityRelationExtractor,
     OnError,
@@ -35,6 +34,8 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+
 logging.basicConfig(level=logging.INFO)


@@ -83,7 +84,8 @@ async def define_and_run_pipeline(
     pipe = Pipeline()
     pipe.add_component(PdfLoader(), "pdf_loader")
     pipe.add_component(
-        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter"
+        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
+        "splitter",
     )
     pipe.add_component(SchemaBuilder(), "schema")
     pipe.add_component(

examples/customize/build_graph/pipeline/kg_builder_from_text.py

Lines changed: 3 additions & 2 deletions
@@ -16,7 +16,6 @@

 import asyncio

-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -37,6 +36,8 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+

 async def define_and_run_pipeline(
     neo4j_driver: neo4j.Driver, llm: LLMInterface
@@ -58,7 +59,7 @@ async def define_and_run_pipeline(
     # define the components
     pipe.add_component(
         # chunk_size=50 for the sake of this demo
-        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200),
+        FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
         "splitter",
     )
     pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")

examples/customize/build_graph/pipeline/lexical_graph_builder_from_text.py

Lines changed: 3 additions & 2 deletions
@@ -2,7 +2,6 @@

 import asyncio

-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
@@ -14,6 +13,8 @@
 from neo4j_graphrag.experimental.pipeline import Pipeline
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult

+import neo4j
+

 async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
     """This is where we define and run the Lexical Graph builder pipeline, instantiating
@@ -27,7 +28,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
     pipe = Pipeline()
     # define the components
     pipe.add_component(
-        FixedSizeSplitter(chunk_size=20, chunk_overlap=1),
+        FixedSizeSplitter(chunk_size=20, chunk_overlap=1, approximate=False),
         "splitter",
     )
     pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")

examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -7,7 +7,6 @@

 import asyncio

-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -29,6 +28,8 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+

 async def define_and_run_pipeline(
     neo4j_driver: neo4j.Driver,
@@ -56,7 +57,7 @@ async def define_and_run_pipeline(
     pipe = Pipeline()
     # define the components
     pipe.add_component(
-        FixedSizeSplitter(chunk_size=200, chunk_overlap=50),
+        FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False),
         "splitter",
     )
     pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")

examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py

Lines changed: 3 additions & 2 deletions
@@ -8,7 +8,6 @@

 import asyncio

-import neo4j
 from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
 from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
 from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -31,6 +30,8 @@
 from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
 from neo4j_graphrag.llm import LLMInterface, OpenAILLM

+import neo4j
+

 async def build_lexical_graph(
     neo4j_driver: neo4j.Driver,
@@ -47,7 +48,7 @@ async def build_lexical_graph(
     pipe = Pipeline()
     # define the components
     pipe.add_component(
-        FixedSizeSplitter(chunk_size=200, chunk_overlap=50),
+        FixedSizeSplitter(chunk_size=200, chunk_overlap=50, approximate=False),
         "splitter",
     )
     pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")

src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py

Lines changed: 91 additions & 7 deletions
@@ -18,12 +18,66 @@
 from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks


+def _adjust_chunk_start(text: str, approximate_start: int) -> int:
+    """
+    Shift the starting index backward if it lands in the middle of a word.
+    If no whitespace is found, use the proposed start.
+
+    Args:
+        text (str): The text being split.
+        approximate_start (int): The initial starting index of the chunk.
+
+    Returns:
+        int: The adjusted starting index, ensuring the chunk does not begin in the
+        middle of a word if possible.
+    """
+    start = approximate_start
+    if start > 0 and not text[start].isspace() and not text[start - 1].isspace():
+        while start > 0 and not text[start - 1].isspace():
+            start -= 1
+
+        # fallback if no whitespace is found
+        if start == 0 and not text[0].isspace():
+            start = approximate_start
+    return start
+
+
+def _adjust_chunk_end(text: str, start: int, approximate_end: int) -> int:
+    """
+    Shift the ending index backward if it lands in the middle of a word.
+    If no whitespace is found, use 'approximate_end'.
+
+    Args:
+        text (str): The full text being split.
+        start (int): The adjusted starting index for this chunk.
+        approximate_end (int): The initial end index.
+
+    Returns:
+        int: The adjusted ending index, ensuring the chunk does not end in the middle of
+        a word if possible.
+    """
+    end = approximate_end
+    if end < len(text):
+        while end > start and not text[end].isspace() and not text[end - 1].isspace():
+            end -= 1
+
+        # fallback if no whitespace is found
+        if end == start:
+            end = approximate_end
+    return end
+
+
 class FixedSizeSplitter(TextSplitter):
-    """Text splitter which splits the input text into fixed size chunks with optional overlap.
+    """Text splitter which splits the input text into fixed or approximate fixed size
+    chunks with optional overlap.

     Args:
         chunk_size (int): The number of characters in each chunk.
-        chunk_overlap (int): The number of characters from the previous chunk to overlap with each chunk. Must be less than `chunk_size`.
+        chunk_overlap (int): The number of characters from the previous chunk to overlap
+            with each chunk. Must be less than `chunk_size`.
+        approximate (bool): If True, avoids splitting words in the middle at chunk
+            boundaries. Defaults to True.
+

     Example:
@@ -33,16 +87,21 @@ class FixedSizeSplitter(TextSplitter):
         from neo4j_graphrag.experimental.pipeline import Pipeline

         pipeline = Pipeline()
-        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
+        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=True)
         pipeline.add_component(text_splitter, "text_splitter")
     """

     @validate_call
-    def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None:
+    def __init__(
+        self, chunk_size: int = 4000, chunk_overlap: int = 200, approximate: bool = True
+    ) -> None:
+        if chunk_size <= 0:
+            raise ValueError("chunk_size must be strictly greater than 0")
         if chunk_overlap >= chunk_size:
             raise ValueError("chunk_overlap must be strictly less than chunk_size")
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap
+        self.approximate = approximate

     @validate_call
     async def run(self, text: str) -> TextChunks:
@@ -56,10 +115,35 @@ async def run(self, text: str) -> TextChunks:
         """
         chunks = []
         index = 0
-        for i in range(0, len(text), self.chunk_size - self.chunk_overlap):
-            start = i
-            end = min(start + self.chunk_size, len(text))
+        step = self.chunk_size - self.chunk_overlap
+        text_length = len(text)
+        approximate_start = 0
+        skip_adjust_chunk_start = False
+        end = 0
+
+        while end < text_length:
+            if self.approximate:
+                start = (
+                    approximate_start
+                    if skip_adjust_chunk_start
+                    else _adjust_chunk_start(text, approximate_start)
+                )
+                # adjust start and end to avoid cutting words in the middle
+                approximate_end = min(start + self.chunk_size, text_length)
+                end = _adjust_chunk_end(text, start, approximate_end)
+                # when avoiding splitting words in the middle is not possible, revert to
+                # initial chunk end and skip adjusting next chunk start
+                skip_adjust_chunk_start = end == approximate_end
+            else:
+                # apply fixed size splitting with possibly words cut in half at chunk
+                # boundaries
+                start = approximate_start
+                end = min(start + self.chunk_size, text_length)

             chunk_text = text[start:end]
             chunks.append(TextChunk(text=chunk_text, index=index))
             index += 1
+
+            approximate_start = start + step
+
         return TextChunks(chunks=chunks)
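
To get a concrete feel for the boundary adjustment, the following sketch exercises the two new helpers directly; they are private, underscore-prefixed functions, so calling them like this is for illustration only:

from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    _adjust_chunk_end,
    _adjust_chunk_start,
)

text = "hello world"  # indices 0-4 "hello", 5 " ", 6-10 "world"

# a proposed start of 8 lands inside "world", so it shifts back to index 6
assert _adjust_chunk_start(text, approximate_start=8) == 6

# a proposed end of 8 also lands inside "world"; it shifts back to index 6,
# so the chunk text[0:6] ends cleanly after "hello "
assert _adjust_chunk_end(text, start=0, approximate_end=8) == 6

# with no whitespace anywhere, both helpers fall back to the proposed index
assert _adjust_chunk_start("abcdefghij", approximate_start=5) == 5
assert _adjust_chunk_end("abcdefghij", start=0, approximate_end=5) == 5

The fallback to the proposed index is what keeps the splitter well behaved on whitespace-free input: chunks stay at most chunk_size characters and the loop still advances by a fixed step.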
