Skip to content

Commit 80e8b83

Browse files
Fixes #2998 add kwargs to ChunkrReaderConfig (#3228)
Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com>
1 parent 982543e commit 80e8b83

File tree

2 files changed

+95
-6
lines changed

2 files changed

+95
-6
lines changed

camel/loaders/chunkr_reader.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,23 @@ class ChunkrReaderConfig:
3434
high_resolution (bool, optional): Whether to use high resolution OCR.
3535
(default: :obj:`True`)
3636
ocr_strategy (str, optional): The OCR strategy. Defaults to 'Auto'.
37+
**kwargs: Additional keyword arguments to pass to the Chunkr Configuration.
38+
This accepts all other Configuration parameters such as expires_in,
39+
pipeline, segment_processing, segmentation_strategy, etc.
40+
See: https://github.com/lumina-ai-inc/chunkr/blob/main/core/src/models/task.rs#L749
3741
"""
3842

3943
def __init__(
4044
self,
4145
chunk_processing: int = 512,
4246
high_resolution: bool = True,
4347
ocr_strategy: str = "Auto",
48+
**kwargs,
4449
):
4550
self.chunk_processing = chunk_processing
4651
self.high_resolution = high_resolution
4752
self.ocr_strategy = ocr_strategy
53+
self.kwargs = kwargs
4854

4955

5056
class ChunkrReader:
@@ -175,11 +181,7 @@ def _to_chunkr_configuration(
175181
Returns:
176182
Configuration: Chunkr SDK configuration.
177183
"""
178-
from chunkr_ai.models import (
179-
ChunkProcessing,
180-
Configuration,
181-
OcrStrategy,
182-
)
184+
from chunkr_ai.models import ChunkProcessing, Configuration, OcrStrategy
183185

184186
return Configuration(
185187
chunk_processing=ChunkProcessing(
@@ -190,4 +192,5 @@ def _to_chunkr_configuration(
190192
"Auto": OcrStrategy.AUTO,
191193
"All": OcrStrategy.ALL,
192194
}.get(chunkr_config.ocr_strategy, OcrStrategy.ALL),
195+
**chunkr_config.kwargs,
193196
)

test/loaders/test_chunkr_reader.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import json
1616
import unittest
17-
from unittest.mock import AsyncMock, patch
17+
from unittest.mock import AsyncMock, MagicMock, patch
1818

1919
from chunkr_ai.models import Status
2020

@@ -119,6 +119,92 @@ async def test_get_task_output_poll_error(self, mock_chunkr_class):
119119
mock_chunkr_instance.get_task.assert_called_once_with("test_task_id")
120120
mock_task.poll.assert_called_once()
121121

122+
def test_chunkr_reader_config_defaults(self):
    """A config built with no arguments exposes the default values."""
    default_config = ChunkrReaderConfig()

    # Insertion order is preserved, so attributes are checked in the
    # order listed here.
    expected_attributes = {
        "chunk_processing": 512,
        "high_resolution": True,
        "ocr_strategy": "Auto",
        "kwargs": {},
    }
    for attribute, expected in expected_attributes.items():
        self.assertEqual(getattr(default_config, attribute), expected)
130+
131+
def test_chunkr_reader_config_custom_values(self):
    """Explicitly passed options are stored unchanged on the config."""
    # One mapping drives both construction and verification, so the two
    # cannot drift apart.
    overrides = {
        "chunk_processing": 1024,
        "high_resolution": False,
        "ocr_strategy": "All",
    }
    custom_config = ChunkrReaderConfig(**overrides)

    for attribute, expected in overrides.items():
        self.assertEqual(getattr(custom_config, attribute), expected)
142+
143+
def test_chunkr_reader_config_with_kwargs(self):
    """Unknown keyword arguments are collected in ``config.kwargs``."""
    reader_config = ChunkrReaderConfig(
        chunk_processing=2048,
        expires_in=3600,
        pipeline="custom_pipeline",
    )

    # The named option stays a regular attribute ...
    self.assertEqual(reader_config.chunk_processing, 2048)

    # ... while the extra options are only reachable through ``kwargs``.
    extra_options = reader_config.kwargs
    self.assertEqual(extra_options["expires_in"], 3600)
    self.assertEqual(extra_options["pipeline"], "custom_pipeline")
154+
155+
@patch('chunkr_ai.models.Configuration')
@patch('chunkr_ai.models.ChunkProcessing')
@patch('chunkr_ai.models.OcrStrategy')
def test_to_chunkr_configuration_auto_strategy(
    self, mock_ocr_strategy, mock_chunk_processing, mock_configuration
):
    """``"Auto"`` is mapped to ``OcrStrategy.AUTO`` with no extra kwargs."""
    # Stand-ins for the Chunkr SDK objects; patch decorators apply
    # bottom-up, hence the reversed parameter order.
    chunk_processing_stub = MagicMock()
    mock_chunk_processing.return_value = chunk_processing_stub
    mock_ocr_strategy.AUTO = "AUTO"

    reader_config = ChunkrReaderConfig(
        chunk_processing=512, high_resolution=True, ocr_strategy="Auto"
    )
    self.reader._to_chunkr_configuration(reader_config)

    mock_chunk_processing.assert_called_once_with(target_length=512)
    expected_configuration_kwargs = {
        "chunk_processing": chunk_processing_stub,
        "high_resolution": True,
        "ocr_strategy": "AUTO",
    }
    mock_configuration.assert_called_once_with(
        **expected_configuration_kwargs
    )
178+
179+
@patch('chunkr_ai.models.Configuration')
@patch('chunkr_ai.models.ChunkProcessing')
@patch('chunkr_ai.models.OcrStrategy')
def test_to_chunkr_configuration_with_kwargs(
    self, mock_ocr_strategy, mock_chunk_processing, mock_configuration
):
    """Test _to_chunkr_configuration with additional kwargs.

    Extra keyword arguments given to ``ChunkrReaderConfig`` must reach
    the (patched) ``Configuration`` constructor next to the three named
    options, and the object it builds must be what the method returns.
    """
    config = ChunkrReaderConfig(
        chunk_processing=512,
        high_resolution=True,
        ocr_strategy="Auto",
        expires_in=7200,
        pipeline="test_pipeline",
    )

    mock_chunk_processing_instance = MagicMock()
    mock_chunk_processing.return_value = mock_chunk_processing_instance
    mock_ocr_strategy.AUTO = "AUTO"

    result = self.reader._to_chunkr_configuration(config)

    mock_chunk_processing.assert_called_once_with(target_length=512)
    mock_configuration.assert_called_once_with(
        chunk_processing=mock_chunk_processing_instance,
        high_resolution=True,
        ocr_strategy="AUTO",
        expires_in=7200,
        pipeline="test_pipeline",
    )
    # Previously ``result`` was assigned but never checked; verify the
    # method hands back the object produced by ``Configuration(...)``.
    self.assertIs(result, mock_configuration.return_value)
122208

123209
if __name__ == "__main__":
124210
unittest.main()

0 commit comments

Comments
 (0)