Skip to content

Commit 1e76e89

Browse files
author
Marta
committed
fix tests
1 parent 7f0bcfe commit 1e76e89

File tree

9 files changed

+389
-18
lines changed

9 files changed

+389
-18
lines changed

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ description = "Logprobs for OpenAI Structured Outputs"
55
authors = [{ name = "Sarus Technologies", email = "nicolas.grislain@gmail.com" }]
66
readme = "README.md"
77
keywords = ['python']
8-
requires-python = ">=3.9,<4.0"
8+
requires-python = ">=3.10,<4.0"
99
classifiers = [
1010
"Intended Audience :: Developers",
1111
"Programming Language :: Python",
@@ -18,9 +18,9 @@ classifiers = [
1818
"Topic :: Software Development :: Libraries :: Python Modules",
1919
]
2020
dependencies = [
21-
"openai>=1.58.1",
22-
"pydantic>=2.10.4",
23-
"lark>=1.2.2",
21+
"openai~=1.58.1",
22+
"pydantic~=2.10.4",
23+
"lark~=1.2.2",
2424
]
2525

2626
[project.urls]

structured_logprobs/helpers.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from typing import Any
2+
3+
from lark import Lark, Token, Transformer_NonRecursive, Tree, v_args
4+
from lark.tree import Meta
5+
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
6+
from pydantic import BaseModel
7+
8+
9+
class HasProb(BaseModel):
    """An atomic JSON value together with its character span in the source
    text and the cumulative log-probability of the tokens that produced it."""

    # Parsed atomic value (float, str, bool, ...).
    value: Any
    # Start character offset of the value within the JSON string.
    start: int
    # End character offset (exclusive) of the value within the JSON string.
    end: int
    # Sum of token log-probabilities over the [start, end) character span.
    logprob: float
14+
15+
16+
# Lark grammar for JSON. It is parsed with the LALR parser constructed below;
# the named atomic rules (number, string) are the hooks the transformers use
# to attach log-probabilities to values.
# NOTE(review): "true"/"false"/"null" are anonymous keyword tokens here (no
# rule alias), so they produce empty 'value' subtrees rather than firing the
# true/false/null transformer callbacks — confirm this is intended.
json_grammar = r"""
    start: value

    ?value: object #'?' is a Lark convention indicating that the rule can return the value directly instead of creating a separate parse tree node.
          | array
          | string
          | SIGNED_NUMBER -> number #'-> number' specifies an alias for the rule
          | "true"
          | "false"
          | "null"

    array : "[" [value ("," value)*] "]"
    object : "{" [pair ("," pair)*] "}"
    pair : key ":" value
    key : ESCAPED_STRING

    string : ESCAPED_STRING

    %import common.ESCAPED_STRING
    %import common.SIGNED_NUMBER
    %import common.WS
    %ignore WS
"""
40+
41+
42+
# Transformer that processes the tree and substitutes each atomic value with
# the cumulative log-probability of its tokens.
@v_args(meta=True)
class Extractor(Transformer_NonRecursive):
    """Rebuild the parsed JSON structure, replacing every atomic value with
    the summed log-probability of the completion tokens spanning it."""

    def __init__(self, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]):
        super().__init__()
        self.tokens = tokens
        self.token_indices = token_indices

    def _compute_logprob_sum(self, start: int, end: int) -> float:
        # Translate character offsets into token indices, then sum the
        # log-probabilities of every token in that half-open range.
        first = self.token_indices[start]
        last = self.token_indices[end]
        return sum(tok.logprob for tok in self.tokens[first:last])

    def number(self, meta: Meta, children: list[Token]) -> float:
        return self._compute_logprob_sum(meta.start_pos, meta.end_pos)

    def string(self, meta: Meta, children: list[Token]) -> float:
        return self._compute_logprob_sum(meta.start_pos, meta.end_pos)

    def true(self, meta: Meta, children: list[Token]) -> float:
        return self._compute_logprob_sum(meta.start_pos, meta.end_pos)

    def false(self, meta: Meta, children: list[Token]) -> float:
        return self._compute_logprob_sum(meta.start_pos, meta.end_pos)

    def null(self, meta: Meta, children: list[Token]) -> None:
        return None

    def array(self, meta: Meta, children: list[Any]) -> list[float]:
        return children

    def object(self, meta: Meta, children: list[tuple[str, Any]]) -> dict[str, Any]:
        # Each child is a (key, value) pair produced by `pair`.
        return dict(children)

    def pair(self, meta: Meta, children: list[Any]) -> tuple[str, Any]:
        key, value = children
        # An empty 'value' subtree, e.g. ['b', Tree(Token('RULE', 'value'), [])],
        # is mapped to None.
        if isinstance(value, Tree) and not value.children:
            value = None
        return key, value

    def key(self, meta: Meta, children: list[Token]) -> str:
        # Strip the surrounding double quotes of the ESCAPED_STRING token.
        return children[0][1:-1]

    def start(self, meta: Meta, children: list[dict[str, Any]]) -> dict[str, Any]:
        return children[0]
96+
97+
98+
# Module-level parser built once from the grammar: LALR for speed,
# propagate_positions so transformers can map nodes back to character offsets.
json_parser = Lark(json_grammar, parser="lalr", propagate_positions=True, maybe_placeholders=False)
99+
100+
101+
def extract_json_data(
    json_string: str, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]
) -> dict[str, Any]:
    """Parse *json_string* and replace each atomic value with the cumulative
    log-probability of the tokens that produced it.

    Args:
        json_string: The JSON text emitted by the model.
        tokens: Token-level log-probabilities from the chat completion.
        token_indices: For each character of *json_string*, the index of the
            token that produced it.

    Returns:
        A dict mirroring the JSON structure, with atomic values replaced by
        their summed log-probabilities.
    """
    # Reuse the module-level `json_parser` instead of rebuilding the grammar
    # on every call — the local reconstruction shadowed it and wasted work.
    tree = json_parser.parse(json_string)
    extractor = Extractor(tokens, token_indices)
    return extractor.transform(tree)
108+
109+
110+
# Transformer that embeds log-probabilities for atomic values as in-line
# fields in dictionaries.
@v_args(meta=True)
class ExtractorInline(Transformer_NonRecursive):
    """Rebuild the parsed JSON, adding a sibling '<key>_logprob' entry next
    to every dict key whose value is atomic."""

    def __init__(self, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]):
        super().__init__()
        self.tokens = tokens
        self.token_indices = token_indices

    def _compute_logprob_sum(self, start: int, end: int) -> float:
        # Translate character offsets into token indices, then sum the
        # log-probabilities of every token in that half-open range.
        first = self.token_indices[start]
        last = self.token_indices[end]
        return sum(tok.logprob for tok in self.tokens[first:last])

    def number(self, meta: Meta, children: list[Token]) -> HasProb:
        total = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
        return HasProb(value=float(children[0]), start=meta.start_pos, end=meta.end_pos, logprob=total)

    def string(self, meta: Meta, children: list[Token]) -> HasProb:
        total = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
        # children[0] is the raw ESCAPED_STRING token; drop its quotes.
        return HasProb(value=children[0][1:-1], start=meta.start_pos, end=meta.end_pos, logprob=total)

    def true(self, meta: Meta, children: list[Token]) -> HasProb:
        total = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
        return HasProb(value=True, start=meta.start_pos, end=meta.end_pos, logprob=total)

    def false(self, meta: Meta, children: list[Token]) -> HasProb:
        total = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
        return HasProb(value=False, start=meta.start_pos, end=meta.end_pos, logprob=total)

    def null(self, meta: Meta, children: list[Token]) -> None:
        return None

    def array(self, meta: Meta, children: list[dict[str, Any] | Any]) -> list[dict[str, Any] | Any]:
        # Array elements have no key to attach a '_logprob' field to,
        # so unwrap HasProb items down to their bare values.
        return [item.value if isinstance(item, HasProb) else item for item in children]

    def object(self, meta: Meta, children: list[tuple[str, Any]]) -> dict[str, Any]:
        result: dict[str, Any] = {}
        for key, value in children:
            if isinstance(value, HasProb):
                # Atomic value: store it and a sibling logprob entry.
                result[key] = value.value
                result[f"{key}_logprob"] = value.logprob
            else:
                result[key] = value
        return result

    def pair(self, meta: Meta, children: list[str | Any]) -> tuple[str, Any]:
        key, value = children
        # An empty 'value' subtree, e.g. ['b', Tree(Token('RULE', 'value'), [])],
        # is mapped to None.
        if isinstance(value, Tree) and not value.children:
            value = None
        return key, value

    def key(self, meta: Meta, children: list[Token]) -> str:
        # Strip the surrounding double quotes of the ESCAPED_STRING token.
        return children[0][1:-1]

    def start(self, meta: Meta, children: list[dict[str, Any]]) -> dict[str, Any]:
        return children[0]
168+
169+
170+
def extract_json_data_inline(
    json_string: str, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]
) -> dict[str, Any]:
    """Parse *json_string* and embed per-value log-probabilities as in-line
    '<key>_logprob' fields in the resulting dictionaries.

    Args:
        json_string: The JSON text emitted by the model.
        tokens: Token-level log-probabilities from the chat completion.
        token_indices: For each character of *json_string*, the index of the
            token that produced it.

    Returns:
        A dict mirroring the JSON structure with extra '<key>_logprob'
        entries next to atomic values.
    """
    # Reuse the module-level `json_parser` instead of rebuilding the grammar
    # on every call — the local reconstruction shadowed it and wasted work.
    tree = json_parser.parse(json_string)
    extractor = ExtractorInline(tokens, token_indices)
    return extractor.transform(tree)

structured_logprobs/main.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
66
from pydantic import BaseModel
77

8-
from helpers import extract_json_data, extract_json_data_inline
8+
from structured_logprobs.helpers import extract_json_data, extract_json_data_inline
99

1010
MISSING_LOGPROBS_MESSAGE = "The 'logprobs' field is missing"
1111

@@ -48,16 +48,11 @@ def map_characters_to_token_indices(extracted_data_token: list[ChatCompletionTok
4848
[0, 1, 1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4]
4949
"""
5050

51-
json_output = "".join(token_data.token for token_data in extracted_data_token)
52-
53-
token_indices = [-1] * len(json_output)
54-
current_char_pos = 0
51+
token_indices = []
5552

5653
for token_idx, token_data in enumerate(extracted_data_token):
5754
token_text = token_data.token
58-
for _ in range(len(token_text)):
59-
token_indices[current_char_pos] = token_idx
60-
current_char_pos += 1
55+
token_indices.extend([token_idx] * len(token_text))
6156

6257
return token_indices
6358

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CalendarEvent(BaseModel):
2121
def chat_completion(pytestconfig) -> ChatCompletion:
2222
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
2323
base_path = Path(pytestconfig.rootdir) # Base directory where pytest was run
24-
schema_path = base_path / "tests" / "questions_json_schema.json"
24+
schema_path = base_path / "tests" / "resources" / "questions_json_schema.json"
2525
with open(schema_path) as f:
2626
schema_content = json.load(f)
2727

@@ -74,7 +74,7 @@ class CalendarEvent(BaseModel):
7474
@pytest.fixture
7575
def simple_parsed_completion(pytestconfig) -> ParsedChatCompletion[CalendarEvent] | None:
7676
base_path = Path(pytestconfig.rootdir) # Base directory where pytest was run
77-
with open(base_path / "tests" / "simple_parsed_completion.json") as f:
77+
with open(base_path / "tests" / "resources" / "simple_parsed_completion.json") as f:
7878
return ParsedChatCompletion[CalendarEvent].model_validate_json(f.read())
7979
return None
8080

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"type": "json_schema",
3+
"json_schema": {
4+
"name": "answears",
5+
"description": "Response to questions in JSON format",
6+
"schema": {
7+
"type": "object",
8+
"properties": {
9+
"capital_of_France": { "type": "string" },
10+
"the_two_nicest_colors": {
11+
"type": "array",
12+
"items": {
13+
"type": "string",
14+
"enum": ["red", "blue", "green", "yellow", "purple"]
15+
}
16+
},
17+
"die_shows": { "type": "integer" }
18+
},
19+
"required": [
20+
"capital_of_France",
21+
"the_two_nicest_colors",
22+
"die_shows"
23+
],
24+
"additionalProperties": false
25+
},
26+
"strict": true
27+
}
28+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"type": "json_schema",
3+
"json_schema": {
4+
"name": "event_extraction",
5+
"description": "Extract details about an event, including participants, event name, and date.",
6+
"schema": {
7+
"type": "object",
8+
"properties": {
9+
"name": { "type": "string" },
10+
"date": { "type": "string" },
11+
"participants": {
12+
"type": "array",
13+
"items": { "type": "string" }
14+
}
15+
},
16+
"required": ["name", "date", "participants"],
17+
"additionalProperties": false
18+
},
19+
"strict": true
20+
}
21+
}

0 commit comments

Comments
 (0)