Skip to content

Commit 0f01b14

Browse files
Merge pull request #22 from shcherbak-ai/dev
v0.4.0
2 parents f1ffee4 + 4e47bfe commit 0f01b14

File tree

11 files changed

+152
-16
lines changed

11 files changed

+152
-16
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55

66
- **Refactor**: Code reorganization that doesn't change functionality but improves structure or maintainability
77

8+
## [0.4.0](https://github.com/shcherbak-ai/contextgem/releases/tag/v0.4.0) - 2025-05-20
9+
### Added
10+
- Support for local SaT model paths in Document's `sat_model_id` parameter
11+
812
## [0.3.0](https://github.com/shcherbak-ai/contextgem/releases/tag/v0.3.0) - 2025-05-19
913
### Added
1014
- Expanded JsonObjectConcept to support nested class hierarchies, nested dictionary structures, lists containing objects, and literal types.

contextgem/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ContextGem - Effortless LLM extraction from documents
2121
"""
2222

23-
__version__ = "0.3.0"
23+
__version__ = "0.4.0"
2424
__author__ = "Shcherbak AI AS"
2525

2626
from contextgem.public import (

contextgem/internal/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
NonEmptyStr,
7373
ReferenceDepth,
7474
SaTModelId,
75+
StandardSaTModelId,
7576
_deserialize_type_hint,
7677
_dynamic_pydantic_model,
7778
_format_dict_structure,
@@ -126,6 +127,7 @@
126127
"DefaultPromptType",
127128
"ReferenceDepth",
128129
"SaTModelId",
130+
"StandardSaTModelId",
129131
"LanguageRequirement",
130132
"JustificationDepth",
131133
"AsyncCalsAndKwargs",

contextgem/internal/typings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
NonEmptyStr,
3030
ReferenceDepth,
3131
SaTModelId,
32+
StandardSaTModelId,
3233
)
3334
from contextgem.internal.typings.strings_to_types import _deserialize_type_hint
3435
from contextgem.internal.typings.typed_class_utils import (
@@ -58,6 +59,7 @@
5859
"DefaultPromptType",
5960
"ReferenceDepth",
6061
"SaTModelId",
62+
"StandardSaTModelId",
6163
"LanguageRequirement",
6264
"JustificationDepth",
6365
"AsyncCalsAndKwargs",

contextgem/internal/typings/aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727

2828
import sys
2929
from decimal import Decimal
30-
from typing import Annotated, Any, Callable, Coroutine, Literal, TypeVar
30+
from pathlib import Path
31+
from typing import Annotated, Any, Callable, Coroutine, Literal, TypeVar, Union
3132

3233
from pydantic import Field, StrictStr, StringConstraints
3334

@@ -54,7 +55,8 @@
5455

5556
ReferenceDepth = Literal["paragraphs", "sentences"]
5657

57-
SaTModelId = Literal[
58+
# Define standard SaT model IDs as a separate type
59+
StandardSaTModelId = Literal[
5860
"sat-1l",
5961
"sat-1l-sm",
6062
"sat-3l",
@@ -66,6 +68,13 @@
6668
"sat-12l-sm",
6769
]
6870

71+
# Combined type for sat_model_id parameter
72+
SaTModelId = Union[
73+
StandardSaTModelId,
74+
str, # Local path as a string
75+
Path, # Local path as a Path object
76+
]
77+
6978
LanguageRequirement = Literal["en", "adapt"]
7079

7180
JustificationDepth = Literal["brief", "balanced", "comprehensive"]

contextgem/internal/utils.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from collections import defaultdict
3030
from functools import lru_cache
3131
from pathlib import Path
32-
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, TypeVar
32+
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, TypeVar, get_args
3333

3434
from jinja2 import Environment, Template, nodes
3535
from wtpsplit import SaT
@@ -51,6 +51,7 @@
5151
ExtractedInstanceType,
5252
ReferenceDepth,
5353
SaTModelId,
54+
StandardSaTModelId,
5455
)
5556

5657
T = TypeVar("T")
@@ -586,17 +587,69 @@ def _validate_parsed_llm_output(
586587
def _get_sat_model(model_id: SaTModelId = "sat-3l-sm") -> SaT:
587588
"""
588589
Retrieves and caches a SaT model to be used for paragraphs and sentence segmentation.
590+
Performs validation of the model ID or path before attempting to load the model.
589591
590592
:param model_id:
591593
The identifier of the SaT model. Defaults to "sat-3l-sm".
594+
Can be:
595+
- A standard SaT model ID (e.g., "sat-3l-sm")
596+
- A local path to a SaT model directory (as a string or Path object)
592597
593598
:return:
594599
An instance of the SaT model associated with the given `model_id`.
595-
"""
596-
logger.info(f"Loading SaT model {model_id}...")
597-
model = SaT(model_id)
598-
logger.info(f"SaT model {model_id} loaded.")
599-
return model
600+
601+
:raises ValueError:
602+
If the provided path doesn't exist or is not a directory.
603+
:raises RuntimeError:
604+
If the provided path exists but does not contain a valid SaT model.
605+
"""
606+
# Convert Path object to string if needed
607+
if isinstance(model_id, Path):
608+
model_id = str(model_id)
609+
610+
# Check if it's a standard model ID
611+
is_standard_model = False
612+
if isinstance(model_id, str):
613+
# Get standard models directly from the type definition
614+
standard_models = get_args(StandardSaTModelId)
615+
is_standard_model = model_id in standard_models
616+
617+
# Determine if it's a local path (but not a standard model ID)
618+
is_local_path = False
619+
if isinstance(model_id, str) and not is_standard_model:
620+
path = Path(model_id)
621+
622+
# Validate that the path exists and is a directory
623+
if not path.exists() or not path.is_dir():
624+
raise ValueError(
625+
f"The provided SaT model path '{model_id}' does not exist or is not a directory."
626+
)
627+
628+
is_local_path = True
629+
630+
# Log appropriate message
631+
if is_local_path:
632+
logger.info(f"Loading SaT model from local path {model_id}...")
633+
else:
634+
logger.info(f"Loading SaT model {model_id}...")
635+
636+
# Attempt to load the model
637+
try:
638+
model = SaT(model_id)
639+
logger.info(f"SaT model loaded successfully.")
640+
return model
641+
except Exception as e:
642+
if is_local_path:
643+
# If it's a local path that exists but isn't a valid SaT model
644+
logger.error(f"Failed to load SaT model from path '{model_id}': {str(e)}")
645+
raise RuntimeError(
646+
f"The directory at '{model_id}' exists but does not contain a valid SaT model. "
647+
f"Error: {str(e)}"
648+
) from e
649+
else:
650+
# For standard model IDs or other errors
651+
logger.error(f"Failed to load SaT model '{model_id}': {str(e)}")
652+
raise
600653

601654

602655
def _group_instances_by_fields(

contextgem/public/documents.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import itertools
4040
import warnings
4141
from copy import deepcopy
42+
from pathlib import Path
4243
from typing import Any, Literal, Optional
4344

4445
from pydantic import Field, field_validator, model_validator
@@ -81,8 +82,10 @@ class Document(_AssignedInstancesProcessor):
8182
:ivar paragraph_segmentation_mode: Mode for paragraph segmentation. When set to "sat",
8283
uses a SaT (Segment Any Text https://arxiv.org/abs/2406.16678) model. Defaults to "newlines".
8384
:type paragraph_segmentation_mode: Literal["newlines", "sat"]
84-
:ivar sat_model_id: SaT model ID for paragraph/sentence segmentation.
85-
Defaults to "sat-3l-sm". See https://github.com/segment-any-text/wtpsplit for the list of available models.
85+
:ivar sat_model_id: SaT model ID for paragraph/sentence segmentation or a local path to a SaT model.
86+
For model IDs, defaults to "sat-3l-sm". See https://github.com/segment-any-text/wtpsplit
87+
for the list of available models. For local paths, provide either a string path or a Path
88+
object pointing to the directory containing the SaT model.
8689
:type sat_model_id: SaTModelId
8790
8891
Note:
@@ -285,6 +288,21 @@ def _validate_images(cls, images: list[Image]) -> list[Image]:
285288
seen.add(image.base64_data)
286289
return images
287290

291+
@field_validator("sat_model_id")
292+
@classmethod
293+
def _validate_sat_model_id(cls, sat_model_id: SaTModelId) -> str:
294+
"""
295+
Validates and converts the sat_model_id to ensure it's a string.
296+
If a Path object is provided, it's converted to a string representation.
297+
This conversion ensures the document remains fully serializable.
298+
299+
:param sat_model_id: The SaT model ID or path to validate
300+
:return: String representation of the model ID or path
301+
"""
302+
if isinstance(sat_model_id, Path):
303+
return str(sat_model_id)
304+
return sat_model_id
305+
288306
@model_validator(mode="before")
289307
@classmethod
290308
def _validate_document_pre(cls, data: Any) -> Any:

docs/docs-raw-for-llm.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5575,8 +5575,11 @@ class contextgem.public.documents.Document(**data)
55755575
"newlines".
55765576

55775577
* **sat_model_id** -- SaT model ID for paragraph/sentence
5578-
segmentation. Defaults to "sat-3l-sm". See https://github.com
5579-
/segment-any-text/wtpsplit for the list of available models.
5578+
segmentation or a local path to a SaT model. For model IDs,
5579+
defaults to "sat-3l-sm". See https://github.com/segment-any-
5580+
text/wtpsplit for the list of available models. For local
5581+
paths, provide either a string path or a Path object pointing
5582+
to the directory containing the SaT model.
55805583

55815584
Parameters:
55825585
* **custom_data** (*dict*)

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
project = "ContextGem"
2323
copyright = "2025, Shcherbak AI AS"
2424
author = "Sergii Shcherbak"
25-
release = "0.3.0"
25+
release = "0.4.0"
2626

2727

2828
# Add path to the package

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "contextgem"
3-
version = "0.3.0"
3+
version = "0.4.0"
44
description = "Effortless LLM extraction from documents"
55
authors = [
66
{name = "shcherbak-ai", email = "sergii@shcherbak.ai"}

tests/test_all.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import os
2626
import sys
27+
import tempfile
2728
import warnings
2829
import xml.etree.ElementTree as ET
2930
import zipfile
@@ -71,7 +72,7 @@
7172
dedicated_stream,
7273
logger,
7374
)
74-
from contextgem.internal.utils import _split_text_into_paragraphs
75+
from contextgem.internal.utils import _get_sat_model, _split_text_into_paragraphs
7576
from contextgem.public.utils import JsonObjectClassStruct
7677
from tests.utils import (
7778
VCR_FILTER_HEADERS,
@@ -2511,6 +2512,50 @@ def test_init_document_and_pipeline(self, context: Document | DocumentPipeline):
25112512
with pytest.raises(ValueError):
25122513
context.add_concepts([concept, concept])
25132514

2515+
def test_local_sat_model(self):
2516+
"""
2517+
Tests the loading of a local SaT model.
2518+
"""
2519+
2520+
# Test nonexistent path
2521+
with pytest.raises(ValueError) as exc_info:
2522+
non_existent_path = "/nonexistent/path/to/model"
2523+
_get_sat_model(non_existent_path)
2524+
assert "does not exist or is not a directory" in str(exc_info.value)
2525+
# Document creation should also fail
2526+
with pytest.raises(ValueError):
2527+
Document(
2528+
raw_text="Sample text",
2529+
paragraph_segmentation_mode="sat",
2530+
sat_model_id=non_existent_path,
2531+
)
2532+
2533+
# Test file path (not a directory)
2534+
with tempfile.NamedTemporaryFile() as temp_file:
2535+
with pytest.raises(ValueError) as exc_info:
2536+
_get_sat_model(temp_file.name)
2537+
assert "does not exist or is not a directory" in str(exc_info.value)
2538+
# Document creation should also fail
2539+
with pytest.raises(ValueError):
2540+
Document(
2541+
raw_text="Sample text",
2542+
paragraph_segmentation_mode="sat",
2543+
sat_model_id=temp_file.name,
2544+
)
2545+
2546+
# Test valid path but invalid model
2547+
with tempfile.TemporaryDirectory() as temp_dir:
2548+
with pytest.raises(RuntimeError) as exc_info:
2549+
_get_sat_model(temp_dir)
2550+
assert "does not contain a valid SaT model" in str(exc_info.value)
2551+
# Document creation should also fail
2552+
with pytest.raises(RuntimeError):
2553+
Document(
2554+
raw_text="Sample text",
2555+
paragraph_segmentation_mode="sat",
2556+
sat_model_id=temp_dir,
2557+
)
2558+
25142559
@pytest.mark.vcr()
25152560
def test_system_messages(self):
25162561
"""

0 commit comments

Comments
 (0)