Skip to content

Commit 95bb092

Browse files
authored
Merge pull request #14 from orq-ai/ORQ-1908-include-langchain-callback-snippet
feat: include langchain callback snippet custom code (ORQ-1908)
2 parents 5558681 + a7a38b9 commit 95bb092

File tree

2 files changed

+268
-0
lines changed

2 files changed

+268
-0
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from datetime import datetime, timezone
2+
from enum import Enum
3+
from pydantic import BaseModel, Field
4+
from typing import Any, Dict, List, Optional, Union
5+
from uuid import UUID
6+
7+
import httpx
8+
9+
try:
10+
from langchain.callbacks.base import BaseCallbackHandler # type: ignore
11+
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage # type: ignore
12+
from langchain_core.outputs import LLMResult # type: ignore
13+
except ImportError as exc:
14+
raise ModuleNotFoundError("Please install langchain to use the orq.ai langchain native integration: 'pip install langchain'") from exc
15+
16+
def get_iso_string():
    """Return the current UTC time as an ISO 8601 string with a 'Z' suffix."""
    now_utc = datetime.now(timezone.utc)
    # isoformat() renders UTC as "+00:00"; the API expects the "Z" form.
    stamp = now_utc.isoformat(timespec="milliseconds")
    return stamp.replace("+00:00", "Z")
23+
24+
class EventType(str, Enum):
    """Kind of trace event sent to the orq.ai API."""

    LLM = "llm"
26+
27+
class LlmRole(str, Enum):
    """Role attached to a normalized chat message."""

    SYSTEM = "system"
    USER = "user"
30+
31+
class LlmUsage(BaseModel):
    """Token counts reported by the provider for one LLM call."""

    # Tokens consumed by the prompt.
    input_tokens: int
    # Tokens produced in the completion.
    output_tokens: int
34+
35+
class ChoiceMessage(BaseModel):
    """A single normalized message (prompt or response content plus role)."""

    # Providers may return either a plain string or a list of string parts.
    content: Union[str, List[str]]
    role: LlmRole
38+
39+
class Choice(BaseModel):
    """One response candidate from the model, in OpenAI-style choice shape."""

    # Position of this candidate in the provider's response list.
    index: int
    message: ChoiceMessage
    # Provider-reported stop reason; may be absent.
    finish_reason: Optional[str] = None
43+
44+
class LlmEvent(BaseModel):
    """A single LLM trace event, accumulated across callback hooks.

    Created in ``on_llm_start``/``on_chat_model_start`` and completed in
    ``on_llm_end`` before being posted to the orq.ai API.
    """

    type: EventType
    # Stringified langchain run UUID; also the key in the handler's event map.
    run_id: str
    # Raw serialized model config / metadata / extra kwargs from langchain.
    # Mutable defaults use default_factory — the explicit pydantic idiom.
    parameters: Optional[dict] = Field(default_factory=dict)
    # Plain-text prompts (completion-style models).
    prompts: Optional[List[str]] = Field(default_factory=list)
    # Normalized chat messages (chat-style models).
    messages: List[ChoiceMessage] = Field(default_factory=list)
    start_timestamp: str = Field(default_factory=get_iso_string)
    end_timestamp: Optional[str] = None
    response_choices: List[Choice] = Field(default_factory=list)
    usage: Optional[LlmUsage] = None
54+
55+
class OrqClient():
    """Minimal HTTP client for posting trace events to the orq.ai API."""

    def __init__(self, api_key: str, api_url: str):
        self.api_key = api_key
        self.api_url = api_url

    def log_event(self, event: LlmEvent):
        """POST one event to the langchain traces endpoint (fire-and-forget).

        The HTTP response is intentionally ignored (the original bound it to
        an unused local): tracing must not disrupt the caller's LLM pipeline.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        httpx.post(
            f"{self.api_url}/v2/traces/langchain",
            headers=headers,
            json=event.model_dump(),
        )
66+
67+
class OrqLangchainCallback(BaseCallbackHandler):
    """Langchain callback handler that forwards LLM traces to the orq.ai API."""

    def __init__(self, api_key: str, api_url = "https://my.orq.ai"):
        # In-flight events keyed by stringified run_id; removed in on_llm_end
        # so the map does not grow without bound in long-running processes.
        self.events: Dict[str, LlmEvent] = {}
        self.orq_client = OrqClient(api_key, api_url)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the start of a completion-style LLM run."""
        self.events[str(run_id)] = LlmEvent(
            type=EventType.LLM,
            parameters={
                "serialized": serialized,
                "metadata": metadata,
                "kwargs": kwargs,
            },
            prompts=prompts,
            run_id=str(run_id),
        )

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any
    ) -> Any:
        """Record the start of a chat-model run, normalizing its messages."""
        normalize_messages: List[ChoiceMessage] = []

        for root_messages in messages:
            for message in root_messages:
                # Only human/system messages are mapped; other message types
                # (e.g. AI or tool messages) are intentionally skipped — this
                # mirrors the roles LlmRole can represent.
                if isinstance(message, HumanMessage):
                    normalize_messages.append(ChoiceMessage(role=LlmRole.USER, content=message.content))
                elif isinstance(message, SystemMessage):
                    normalize_messages.append(ChoiceMessage(role=LlmRole.SYSTEM, content=message.content))

        self.events[str(run_id)] = LlmEvent(
            type=EventType.LLM,
            parameters={
                "serialized": serialized,
                "metadata": metadata,
                "kwargs": kwargs,
            },
            messages=normalize_messages,
            run_id=str(run_id),
        )

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Finalize the event for this run and ship it to orq.ai."""
        # pop() both retrieves and releases the in-flight event; without it
        # completed events accumulated in self.events forever.
        event = self.events.pop(str(run_id), None)
        if event is None:
            # No matching start callback (e.g. handler attached mid-run);
            # nothing to log.
            return
        event.end_timestamp = get_iso_string()

        # llm_output is provider-specific: it may be None or lack
        # "token_usage" entirely, so guard instead of indexing blindly.
        token_usage = (response.llm_output or {}).get("token_usage")
        if token_usage:
            event.usage = LlmUsage(
                input_tokens=token_usage.get("prompt_tokens", 0),
                output_tokens=token_usage.get("completion_tokens", 0),
            )

        event.response_choices = []

        for index, choice in enumerate(response.generations[0]):
            # generation_info may be None for some providers.
            generation_info = choice.generation_info or {}
            event.response_choices.append(
                Choice(
                    index=index,
                    message=ChoiceMessage(role=LlmRole.SYSTEM, content=choice.text),
                    finish_reason=generation_info.get("finish_reason"),
                )
            )

        self.orq_client.log_event(event)
131+
132+
# Public API of this integration module.
__all__ = ["OrqLangchainCallback"]

src/orq_ai_sdk/langchain/__init__.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from datetime import datetime, timezone
2+
from enum import Enum
3+
from pydantic import BaseModel, Field
4+
from typing import Any, Dict, List, Optional, Union
5+
from uuid import UUID
6+
7+
import httpx
8+
9+
try:
10+
from langchain.callbacks.base import BaseCallbackHandler # type: ignore
11+
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage # type: ignore
12+
from langchain_core.outputs import LLMResult # type: ignore
13+
except ImportError as exc:
14+
raise ModuleNotFoundError("Please install langchain to use the orq.ai langchain native integration: 'pip install langchain'") from exc
15+
16+
def get_iso_string():
    """Return the current UTC time as an ISO 8601 string with a 'Z' suffix."""
    now_utc = datetime.now(timezone.utc)
    # isoformat() renders UTC as "+00:00"; the API expects the "Z" form.
    stamp = now_utc.isoformat(timespec="milliseconds")
    return stamp.replace("+00:00", "Z")
23+
24+
class EventType(str, Enum):
    """Kind of trace event sent to the orq.ai API."""

    LLM = "llm"
26+
27+
class LlmRole(str, Enum):
    """Role attached to a normalized chat message."""

    SYSTEM = "system"
    USER = "user"
30+
31+
class LlmUsage(BaseModel):
    """Token counts reported by the provider for one LLM call."""

    # Tokens consumed by the prompt.
    input_tokens: int
    # Tokens produced in the completion.
    output_tokens: int
34+
35+
class ChoiceMessage(BaseModel):
    """A single normalized message (prompt or response content plus role)."""

    # Providers may return either a plain string or a list of string parts.
    content: Union[str, List[str]]
    role: LlmRole
38+
39+
class Choice(BaseModel):
    """One response candidate from the model, in OpenAI-style choice shape."""

    # Position of this candidate in the provider's response list.
    index: int
    message: ChoiceMessage
    # Provider-reported stop reason; may be absent.
    finish_reason: Optional[str] = None
43+
44+
class LlmEvent(BaseModel):
    """A single LLM trace event, accumulated across callback hooks.

    Created in ``on_llm_start``/``on_chat_model_start`` and completed in
    ``on_llm_end`` before being posted to the orq.ai API.
    """

    type: EventType
    # Stringified langchain run UUID; also the key in the handler's event map.
    run_id: str
    # Raw serialized model config / metadata / extra kwargs from langchain.
    # Mutable defaults use default_factory — the explicit pydantic idiom.
    parameters: Optional[dict] = Field(default_factory=dict)
    # Plain-text prompts (completion-style models).
    prompts: Optional[List[str]] = Field(default_factory=list)
    # Normalized chat messages (chat-style models).
    messages: List[ChoiceMessage] = Field(default_factory=list)
    start_timestamp: str = Field(default_factory=get_iso_string)
    end_timestamp: Optional[str] = None
    response_choices: List[Choice] = Field(default_factory=list)
    usage: Optional[LlmUsage] = None
54+
55+
class OrqClient():
    """Minimal HTTP client for posting trace events to the orq.ai API."""

    def __init__(self, api_key: str, api_url: str):
        self.api_key = api_key
        self.api_url = api_url

    def log_event(self, event: LlmEvent):
        """POST one event to the langchain traces endpoint (fire-and-forget).

        The HTTP response is intentionally ignored (the original bound it to
        an unused local): tracing must not disrupt the caller's LLM pipeline.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        httpx.post(
            f"{self.api_url}/v2/traces/langchain",
            headers=headers,
            json=event.model_dump(),
        )
66+
67+
class OrqLangchainCallback(BaseCallbackHandler):
    """Langchain callback handler that forwards LLM traces to the orq.ai API."""

    def __init__(self, api_key: str, api_url = "https://my.orq.ai"):
        # In-flight events keyed by stringified run_id; removed in on_llm_end
        # so the map does not grow without bound in long-running processes.
        self.events: Dict[str, LlmEvent] = {}
        self.orq_client = OrqClient(api_key, api_url)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the start of a completion-style LLM run."""
        self.events[str(run_id)] = LlmEvent(
            type=EventType.LLM,
            parameters={
                "serialized": serialized,
                "metadata": metadata,
                "kwargs": kwargs,
            },
            prompts=prompts,
            run_id=str(run_id),
        )

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any
    ) -> Any:
        """Record the start of a chat-model run, normalizing its messages."""
        normalize_messages: List[ChoiceMessage] = []

        for root_messages in messages:
            for message in root_messages:
                # Only human/system messages are mapped; other message types
                # (e.g. AI or tool messages) are intentionally skipped — this
                # mirrors the roles LlmRole can represent.
                if isinstance(message, HumanMessage):
                    normalize_messages.append(ChoiceMessage(role=LlmRole.USER, content=message.content))
                elif isinstance(message, SystemMessage):
                    normalize_messages.append(ChoiceMessage(role=LlmRole.SYSTEM, content=message.content))

        self.events[str(run_id)] = LlmEvent(
            type=EventType.LLM,
            parameters={
                "serialized": serialized,
                "metadata": metadata,
                "kwargs": kwargs,
            },
            messages=normalize_messages,
            run_id=str(run_id),
        )

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Finalize the event for this run and ship it to orq.ai."""
        # pop() both retrieves and releases the in-flight event; without it
        # completed events accumulated in self.events forever.
        event = self.events.pop(str(run_id), None)
        if event is None:
            # No matching start callback (e.g. handler attached mid-run);
            # nothing to log.
            return
        event.end_timestamp = get_iso_string()

        # llm_output is provider-specific: it may be None or lack
        # "token_usage" entirely, so guard instead of indexing blindly.
        token_usage = (response.llm_output or {}).get("token_usage")
        if token_usage:
            event.usage = LlmUsage(
                input_tokens=token_usage.get("prompt_tokens", 0),
                output_tokens=token_usage.get("completion_tokens", 0),
            )

        event.response_choices = []

        for index, choice in enumerate(response.generations[0]):
            # generation_info may be None for some providers.
            generation_info = choice.generation_info or {}
            event.response_choices.append(
                Choice(
                    index=index,
                    message=ChoiceMessage(role=LlmRole.SYSTEM, content=choice.text),
                    finish_reason=generation_info.get("finish_reason"),
                )
            )

        self.orq_client.log_event(event)
131+
132+
# Public API of this integration module.
__all__ = ["OrqLangchainCallback"]

0 commit comments

Comments
 (0)